From 1f25c68cf03cdc3d2d03aa04b50efffbaa61a54f Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 7 Aug 2023 14:46:42 -0400 Subject: [PATCH] Add ParentJoin KNN support (#12434) A `join` within Lucene is built by adding child-docs and parent-docs in order. Since our vector field already supports sparse indexing, it should be able to support parent join indexing. However, when searching for the closest `k`, it is still the k nearest children vectors with no way to join back to the parent. This commit adds this ability through some significant changes: - New leaf reader function that allows a collector for knn results - The knn results can then utilize bit-sets to join back to the parent id This type of support is critical for nearest passage retrieval over larger documents. Generally, you want the top-k documents and knowledge of the nearest passages over each top-k document. Lucene's join functionality is a nice fit for this. This does not replace the need for multi-valued vectors, which is important for other ranking methods (e.g. colbert token embeddings). But, it could be used in the case when metadata about the passage embedding must be stored (e.g. the related passage). --- lucene/CHANGES.txt | 4 + .../word2vec/Word2VecSynonymProvider.java | 15 +- .../lucene90/Lucene90HnswVectorsReader.java | 29 +- .../lucene90/Lucene90OnHeapHnswGraph.java | 2 +- .../lucene91/Lucene91HnswVectorsReader.java | 46 +-- .../lucene92/Lucene92HnswVectorsReader.java | 46 +-- .../lucene92/OffHeapFloatVectorValues.java | 7 - .../lucene94/Lucene94HnswVectorsReader.java | 83 ++--- .../lucene94/OffHeapByteVectorValues.java | 7 - .../lucene94/OffHeapFloatVectorValues.java | 7 - .../lucene91/Lucene91HnswGraphBuilder.java | 16 +- .../SimpleTextKnnVectorsReader.java | 42 +-- .../lucene/codecs/KnnVectorsFormat.java | 10 +- .../lucene/codecs/KnnVectorsReader.java | 18 +- .../lucene95/Lucene95HnswVectorsReader.java | 99 ++---- .../lucene95/OffHeapByteVectorValues.java | 7 - .../lucene95/OffHeapFloatVectorValues.java | 7 - .../perfield/PerFieldKnnVectorsFormat.java | 10 +- .../org/apache/lucene/index/CheckIndex.java | 16 +- .../org/apache/lucene/index/CodecReader.java | 20 +- .../lucene/index/DocValuesLeafReader.java | 10 +- .../lucene/index/ExitableDirectoryReader.java | 16 +- .../apache/lucene/index/FilterLeafReader.java | 14 +- .../org/apache/lucene/index/LeafReader.java | 95 +++++- .../lucene/index/ParallelLeafReader.java | 22 +- .../lucene/index/SlowCodecReaderWrapper.java | 10 +- .../lucene/index/SortingCodecReader.java | 7 +- .../lucene/search/AbstractKnnCollector.java | 69 ++++ .../lucene/search/KnnByteVectorQuery.java | 4 +- .../apache/lucene/search/KnnCollector.java | 88 ++++++ .../lucene/search/KnnFloatVectorQuery.java | 2 +- .../apache/lucene/search/KnnVectorQuery.java | 2 +- .../apache/lucene/search/TopKnnCollector.java | 70 +++++ .../lucene/util/hnsw/HnswGraphBuilder.java | 115 ++++++- .../lucene/util/hnsw/HnswGraphSearcher.java | 162 ++++++---- .../lucene/util/hnsw/IntToIntFunction.java | 23 ++ .../lucene/util/hnsw/NeighborQueue.java | 4 - .../hnsw/OrdinalTranslatedKnnCollector.java | 83 +++++ .../util/hnsw/RandomAccessVectorValues.java | 10 + .../index/TestSegmentToThreadMapping.java | 14 +- .../search/BaseKnnVectorQueryTestCase.java | 14 + .../lucene/search/TestTopKnnResults.java | 41 +++ .../lucene/util/hnsw/HnswGraphTestCase.java | 86 ++--- .../util/hnsw/TestHnswFloatVectorGraph.java | 15 +- .../highlight/TermVectorLeafReader.java | 14 +- .../ToParentBlockJoinByteKnnVectorQuery.java | 189 +++++++++++ .../ToParentBlockJoinFloatKnnVectorQuery.java | 191 ++++++++++++ .../search/join/ToParentJoinKnnCollector.java | 294 ++++++++++++++++++ ...ParentBlockJoinKnnVectorQueryTestCase.java | 281 +++++++++++++++++ .../lucene/search/join/TestBlockJoin.java | 61 ++++ ...TestParentBlockJoinByteKnnVectorQuery.java | 57 ++++ ...estParentBlockJoinFloatKnnVectorQuery.java | 97 ++++++ .../join/TestToParentJoinKnnResults.java | 128 ++++++++ .../lucene/index/memory/MemoryIndex.java | 14 +- .../asserting/AssertingKnnVectorsFormat.java | 16 +- .../tests/index/MergeReaderWrapper.java | 14 +- .../tests/index/MismatchedLeafReader.java | 14 + .../lucene/tests/search/QueryUtils.java | 14 +- 58 files changed, 2273 insertions(+), 578 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/KnnCollector.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/IntToIntFunction.java create mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestTopKnnResults.java create mode 100644 lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinByteKnnVectorQuery.java create mode 100644 lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinFloatKnnVectorQuery.java create mode 100644 lucene/join/src/java/org/apache/lucene/search/join/ToParentJoinKnnCollector.java create mode 100644 lucene/join/src/test/org/apache/lucene/search/join/ParentBlockJoinKnnVectorQueryTestCase.java create mode 100644 lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java create mode 100644 lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java create mode 100644 lucene/join/src/test/org/apache/lucene/search/join/TestToParentJoinKnnResults.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index b5120f63aabb..36a435d58cf4 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -22,6 +22,10 @@ New Features * LUCENE-8183, GITHUB#9231: Added the abbility to get noSubMatches and noOverlappingMatches in HyphenationCompoundWordFilter (Martin Demberger, original from Rupert Westenthaler) +* GITHUB#12434: Add `KnnCollector` to `LeafReader` and `KnnVectorReader` so that custom collection of vector + search results can be provided. The first custom collector provides `ToParentBlockJoin[Float|Byte]KnnVectorQuery` + joining child vector documents with their parent documents. (Ben Trent) + Improvements --------------------- * GITHUB#12374: Add CachingLeafSlicesSupplier to compute the LeafSlices for concurrent segment search (Sorabh Hamirwasia) diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java index 3089f1587a4e..59e9ad0a96fc 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java @@ -25,10 +25,11 @@ import java.util.List; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.HnswGraphSearcher; -import org.apache.lucene.util.hnsw.NeighborQueue; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; /** @@ -73,7 +74,7 @@ public List getSynonyms( LinkedList result = new LinkedList<>(); float[] query = word2VecModel.vectorValue(term); if (query != null) { - NeighborQueue synonyms = + KnnCollector synonyms = HnswGraphSearcher.search( query, // The query vector is in the model. When looking for the top-k @@ -85,16 +86,16 @@ public List getSynonyms( hnswGraph, null, Integer.MAX_VALUE); + TopDocs topDocs = synonyms.topDocs(); - int size = synonyms.size(); - for (int i = 0; i < size; i++) { - float similarity = synonyms.topScore(); - int id = synonyms.pop(); + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + float similarity = topDocs.scoreDocs[i].score; + int id = topDocs.scoreDocs[i].doc; BytesRef synonym = word2VecModel.termValue(id); // We remove the original query term if (!synonym.equals(term) && similarity >= minAcceptedSimilarity) { - result.addFirst(new TermAndBoost(synonym, similarity)); + result.addLast(new TermAndBoost(synonym, similarity)); } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index a22d34cfe231..50b201268aa1 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -34,9 +34,7 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; @@ -237,48 +235,39 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); if (fieldEntry.size() == 0) { - return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + return; } - // bound k by total number of vectors to prevent oversizing data structures - k = Math.min(k, fieldEntry.size()); - OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry); // use a seed that is fixed for the index so we get reproducible results for the same query final SplittableRandom random = new SplittableRandom(checksumSeed); NeighborQueue results = Lucene90OnHeapHnswGraph.search( target, - k, - k, + knnCollector.k(), + knnCollector.k(), vectorValues, fieldEntry.similarityFunction, getGraphValues(fieldEntry), getAcceptOrds(acceptDocs, fieldEntry), - visitedLimit, + knnCollector.visitLimit(), random); - int i = 0; - ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; + knnCollector.incVisitedCount(results.visitedCount()); while (results.size() > 0) { int node = results.topNode(); float minSimilarity = results.topScore(); results.pop(); - scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], minSimilarity); + knnCollector.collect(node, minSimilarity); } - TotalHits.Relation relation = - results.incomplete() - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); } @Override - public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java index 52dae288caa9..a8bab2c14755 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java @@ -78,7 +78,7 @@ public static NeighborQueue search( VectorSimilarityFunction similarityFunction, HnswGraph graphValues, Bits acceptOrds, - int visitedLimit, + long visitedLimit, SplittableRandom random) throws IOException { int size = graphValues.size(); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 462f42c7e520..8bff5a6c01a0 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -35,9 +35,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; @@ -46,7 +44,6 @@ import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; -import org.apache.lucene.util.hnsw.NeighborQueue; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** @@ -229,47 +226,28 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); if (fieldEntry.size() == 0) { - return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + return; } - // bound k by total number of vectors to prevent oversizing data structures - k = Math.min(k, fieldEntry.size()); OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry); - NeighborQueue results = - HnswGraphSearcher.search( - target, - k, - vectorValues, - VectorEncoding.FLOAT32, - fieldEntry.similarityFunction, - getGraph(fieldEntry), - getAcceptOrds(acceptDocs, fieldEntry), - visitedLimit); - - int i = 0; - ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; - while (results.size() > 0) { - int node = results.topNode(); - float minSimilarity = results.topScore(); - results.pop(); - scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), minSimilarity); - } - - TotalHits.Relation relation = - results.incomplete() - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); + HnswGraphSearcher.search( + target, + knnCollector, + vectorValues, + VectorEncoding.FLOAT32, + fieldEntry.similarityFunction, + getGraph(fieldEntry), + getAcceptOrds(acceptDocs, fieldEntry)); } @Override - public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java index 0f6eddfffedd..505dc7640584 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java @@ -34,9 +34,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; @@ -45,7 +43,6 @@ import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; -import org.apache.lucene.util.hnsw.NeighborQueue; import org.apache.lucene.util.packed.DirectMonotonicReader; /** @@ -225,47 +222,28 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); if (fieldEntry.size() == 0) { - return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + return; } - // bound k by total number of vectors to prevent oversizing data structures - k = Math.min(k, fieldEntry.size()); OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData); - NeighborQueue results = - HnswGraphSearcher.search( - target, - k, - vectorValues, - VectorEncoding.FLOAT32, - fieldEntry.similarityFunction, - getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs), - visitedLimit); - - int i = 0; - ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; - while (results.size() > 0) { - int node = results.topNode(); - float minSimilarity = results.topScore(); - results.pop(); - scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), minSimilarity); - } - - TotalHits.Relation relation = - results.incomplete() - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); + HnswGraphSearcher.search( + target, + knnCollector, + vectorValues, + VectorEncoding.FLOAT32, + fieldEntry.similarityFunction, + getGraph(fieldEntry), + vectorValues.getAcceptOrds(acceptDocs)); } @Override - public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index 820b7a3e0843..142c3d20ac0f 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -61,8 +61,6 @@ public float[] vectorValue(int targetOrd) throws IOException { return value; } - public abstract int ordToDoc(int ord); - static OffHeapFloatVectorValues load( Lucene92HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException { if (fieldEntry.docsWithFieldOffset == -2) { @@ -118,11 +116,6 @@ public RandomAccessVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, slice.clone()); } - @Override - public int ordToDoc(int ord) { - return ord; - } - @Override Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index 781681715120..29bfd7b28e3a 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -34,9 +34,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IOContext; @@ -46,7 +44,6 @@ import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; -import org.apache.lucene.util.hnsw.NeighborQueue; import org.apache.lucene.util.packed.DirectMonotonicReader; /** @@ -269,83 +266,45 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { - return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + return; } - // bound k by total number of vectors to prevent oversizing data structures - k = Math.min(k, fieldEntry.size()); OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData); - NeighborQueue results = - HnswGraphSearcher.search( - target, - k, - vectorValues, - fieldEntry.vectorEncoding, - fieldEntry.similarityFunction, - getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs), - visitedLimit); - - int i = 0; - ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; - while (results.size() > 0) { - int node = results.topNode(); - float score = results.topScore(); - results.pop(); - scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score); - } - - TotalHits.Relation relation = - results.incomplete() - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); + HnswGraphSearcher.search( + target, + knnCollector, + vectorValues, + fieldEntry.vectorEncoding, + fieldEntry.similarityFunction, + getGraph(fieldEntry), + vectorValues.getAcceptOrds(acceptDocs)); } @Override - public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { - return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + return; } - // bound k by total number of vectors to prevent oversizing data structures - k = Math.min(k, fieldEntry.size()); OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData); - NeighborQueue results = - HnswGraphSearcher.search( - target, - k, - vectorValues, - fieldEntry.vectorEncoding, - fieldEntry.similarityFunction, - getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs), - visitedLimit); - - int i = 0; - ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; - while (results.size() > 0) { - int node = results.topNode(); - float score = results.topScore(); - results.pop(); - scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score); - } - - TotalHits.Relation relation = - results.incomplete() - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); + HnswGraphSearcher.search( + target, + knnCollector, + vectorValues, + fieldEntry.vectorEncoding, + fieldEntry.similarityFunction, + getGraph(fieldEntry), + vectorValues.getAcceptOrds(acceptDocs)); } private HnswGraph getGraph(FieldEntry entry) throws IOException { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index a94fe3f13a5b..682eb6616ed6 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -69,8 +69,6 @@ private void readValue(int targetOrd) throws IOException { slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); } - public abstract int ordToDoc(int ord); - static OffHeapByteVectorValues load( Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException { if (fieldEntry.docsWithFieldOffset == -2 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { @@ -128,11 +126,6 @@ public RandomAccessVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); } - @Override - public int ordToDoc(int ord) { - return ord; - } - @Override Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 33ba60c6b7bc..3dea01454bd6 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -61,8 +61,6 @@ public float[] vectorValue(int targetOrd) throws IOException { return value; } - public abstract int ordToDoc(int ord); - static OffHeapFloatVectorValues load( Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException { if (fieldEntry.docsWithFieldOffset == -2) { @@ -120,11 +118,6 @@ public RandomAccessVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); } - @Override - public int ordToDoc(int ord) { - return ord; - } - @Override Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java index c82920181cc4..dbd64a162b04 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java @@ -29,6 +29,7 @@ import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.HnswGraphSearcher; import org.apache.lucene.util.hnsw.NeighborQueue; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; @@ -146,7 +147,7 @@ public void setInfoStream(InfoStream infoStream) { /** Inserts a doc with vector value to the graph */ void addGraphNode(int node, float[] value) throws IOException { - NeighborQueue candidates; + HnswGraphBuilder.GraphBuilderKnnCollector candidates; final int nodeLevel = getRandomGraphLevel(ml, random); int curMaxLevel = hnsw.numLevels() - 1; int[] eps = new int[] {hnsw.entryNode()}; @@ -159,12 +160,12 @@ void addGraphNode(int node, float[] value) throws IOException { // for levels > nodeLevel search with topk = 1 for (int level = curMaxLevel; level > nodeLevel; level--) { candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw); - eps = new int[] {candidates.pop()}; + eps = new int[] {candidates.popNode()}; } // for levels <= nodeLevel search with topk = beamWidth, and add connections for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) { candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw); - eps = candidates.nodes(); + eps = candidates.popUntilNearestKNodes(); hnsw.addNode(level, node); addDiverseNeighbors(level, node, candidates); } @@ -188,7 +189,8 @@ private long printGraphBuildStatus(int node, long start, long t) { * work better if we keep the neighbor arrays sorted. Possibly we should switch back to a heap? * But first we should just see if sorting makes a significant difference. */ - private void addDiverseNeighbors(int level, int node, NeighborQueue candidates) + private void addDiverseNeighbors( + int level, int node, HnswGraphBuilder.GraphBuilderKnnCollector candidates) throws IOException { /* For each of the beamWidth nearest candidates (going from best to worst), select it only if it * is closer to target than it is to any of the already-selected neighbors (ie selected in this method, @@ -227,14 +229,14 @@ private void selectDiverse(Lucene91NeighborArray neighbors, Lucene91NeighborArra } } - private void popToScratch(NeighborQueue candidates) { + private void popToScratch(HnswGraphBuilder.GraphBuilderKnnCollector candidates) { scratch.clear(); int candidateCount = candidates.size(); // extract all the Neighbors from the queue into an array; these will now be // sorted from worst to best for (int i = 0; i < candidateCount; i++) { - float similarity = candidates.topScore(); - scratch.add(candidates.pop(), similarity); + float similarity = candidates.minCompetitiveSimilarity(); + scratch.add(candidates.popNode(), similarity); } } diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 3307ca3aaa05..7570445aa21c 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -37,10 +37,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.HitQueue; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.BufferedChecksumIndexInput; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.IOContext; @@ -182,7 +179,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FloatVectorValues values = getFloatVectorValues(field); if (target.length != values.dimension()) { @@ -194,36 +191,25 @@ public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int } FieldInfo info = readState.fieldInfos.fieldInfo(field); VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction(); - HitQueue topK = new HitQueue(k, false); - - int numVisited = 0; - TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; - int doc; while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { if (acceptDocs != null && acceptDocs.get(doc) == false) { continue; } - if (numVisited >= visitedLimit) { - relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + if (knnCollector.earlyTerminated()) { break; } float[] vector = values.vectorValue(); float score = vectorSimilarity.compare(vector, target); - topK.insertWithOverflow(new ScoreDoc(doc, score)); - numVisited++; - } - ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()]; - for (int i = topScoreDocs.length - 1; i >= 0; i--) { - topScoreDocs[i] = topK.pop(); + knnCollector.collect(doc, score); + knnCollector.incVisitedCount(1); } - return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs); } @Override - public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ByteVectorValues values = getByteVectorValues(field); if (target.length != values.dimension()) { @@ -235,10 +221,6 @@ public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int v } FieldInfo info = readState.fieldInfos.fieldInfo(field); VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction(); - HitQueue topK = new HitQueue(k, false); - - int numVisited = 0; - TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO; int doc; while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { @@ -246,21 +228,15 @@ public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int v continue; } - if (numVisited >= visitedLimit) { - relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + if (knnCollector.earlyTerminated()) { break; } byte[] vector = values.vectorValue(); float score = vectorSimilarity.compare(vector, target); - topK.insertWithOverflow(new ScoreDoc(doc, score)); - numVisited++; - } - ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()]; - for (int i = topScoreDocs.length - 1; i >= 0; i--) { - topScoreDocs[i] = topK.pop(); + knnCollector.collect(doc, score); + knnCollector.incVisitedCount(1); } - return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java index a17d844d8111..6d84323f3169 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java @@ -22,7 +22,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.NamedSPILoader; @@ -120,14 +120,14 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public TopDocs search( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { + public void search( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) { throw new UnsupportedOperationException(); } @Override - public TopDocs search( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) { + public void search( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java index d54a5f6d06e8..96530395861a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; @@ -79,14 +80,12 @@ protected KnnVectorsReader() {} * * @param field the vector field to search * @param target the vector-valued query - * @param k the number of docs to return + * @param knnCollector a KnnResults collector and relevant settings for gathering vector results * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. - * @param visitedLimit the maximum number of nodes that the search is allowed to visit - * @return the k nearest neighbor documents, along with their (similarity-specific) scores. */ - public abstract TopDocs search( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException; + public abstract void search( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; /** * Return the k nearest neighbor documents as determined by comparison of their vector values for @@ -109,14 +108,13 @@ public abstract TopDocs search( * * @param field the vector field to search * @param target the vector-valued query - * @param k the number of docs to return + * @param knnCollector a KnnResults collector and relevant settings for gathering vector results * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. - * @param visitedLimit the maximum number of nodes that the search is allowed to visit - * @return the k nearest neighbor documents, along with their (similarity-specific) scores. */ - public abstract TopDocs search( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException; + public abstract void search( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + /** * Returns an instance optimized for merging. This instance may only be consumed in the thread * that called {@link #getMergeInstance()}. diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java index 8e724b3f9c9c..00b19e251f29 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/Lucene95HnswVectorsReader.java @@ -34,9 +34,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IOContext; @@ -48,7 +46,6 @@ import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; -import org.apache.lucene.util.hnsw.NeighborQueue; import org.apache.lucene.util.packed.DirectMonotonicReader; /** @@ -274,89 +271,47 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); - if (fieldEntry.size() == 0) { - return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); - } - if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { - return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + if (fieldEntry.size() == 0 + || knnCollector.k() == 0 + || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + return; } - // bound k by total number of vectors to prevent oversizing data structures - k = Math.min(k, fieldEntry.size()); OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData); - - NeighborQueue results = - HnswGraphSearcher.search( - target, - k, - vectorValues, - fieldEntry.vectorEncoding, - fieldEntry.similarityFunction, - getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs), - visitedLimit); - - int i = 0; - ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; - while (results.size() > 0) { - int node = results.topNode(); - float score = results.topScore(); - results.pop(); - scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score); - } - - TotalHits.Relation relation = - results.incomplete() - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); + HnswGraphSearcher.search( + target, + knnCollector, + vectorValues, + fieldEntry.vectorEncoding, + fieldEntry.similarityFunction, + getGraph(fieldEntry), + vectorValues.getAcceptOrds(acceptDocs)); } @Override - public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); - if (fieldEntry.size() == 0) { - return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); - } - if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) { - return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + if (fieldEntry.size() == 0 + || knnCollector.k() == 0 + || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { + return; } - // bound k by total number of vectors to prevent oversizing data structures - k = Math.min(k, fieldEntry.size()); OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData); - - NeighborQueue results = - HnswGraphSearcher.search( - target, - k, - vectorValues, - fieldEntry.vectorEncoding, - fieldEntry.similarityFunction, - getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs), - visitedLimit); - - int i = 0; - ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; - while (results.size() > 0) { - int node = results.topNode(); - float score = results.topScore(); - results.pop(); - scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), score); - } - - TotalHits.Relation relation = - results.incomplete() - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs); + HnswGraphSearcher.search( + target, + knnCollector, + vectorValues, + fieldEntry.vectorEncoding, + fieldEntry.similarityFunction, + getGraph(fieldEntry), + vectorValues.getAcceptOrds(acceptDocs)); } /** Get knn graph values; used for testing */ diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 1308c35514ae..089148a33000 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -69,8 +69,6 @@ private void readValue(int targetOrd) throws IOException { slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); } - public abstract int ordToDoc(int ord); - static OffHeapByteVectorValues load( Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException { if (fieldEntry.docsWithFieldOffset == -2 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { @@ -128,11 +126,6 @@ public RandomAccessVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); } - @Override - public int ordToDoc(int ord) { - return ord; - } - @Override Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 2ca67f70f16e..0f53e3f9c8b8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -62,8 +62,6 @@ public float[] vectorValue(int targetOrd) throws IOException { return value; } - public abstract int ordToDoc(int ord); - static OffHeapFloatVectorValues load( Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException { if (fieldEntry.docsWithFieldOffset == -2 @@ -122,11 +120,6 @@ public RandomAccessVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); } - @Override - public int ordToDoc(int ord) { - return ord; - } - @Override Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java index 6d344629f1b1..d84e0f7b1c95 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java @@ -34,7 +34,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; @@ -271,15 +271,15 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - return fields.get(field).search(field, target, k, acceptDocs, visitedLimit); + fields.get(field).search(field, target, knnCollector, acceptDocs); } @Override - public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - return fields.get(field).search(field, target, k, acceptDocs, visitedLimit); + fields.get(field).search(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java index 3528bb41dd91..17bca7e4db97 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java +++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java @@ -58,10 +58,12 @@ import org.apache.lucene.index.PointValues.Relation; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.LeafFieldComparator; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; @@ -2654,10 +2656,9 @@ private static void checkFloatVectorValues( while (values.nextDoc() != NO_MORE_DOCS) { // search the first maxNumSearches vectors to exercise the graph if (values.docID() % everyNdoc == 0) { - TopDocs docs = - codecReader - .getVectorReader() - .search(fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE); + KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE); + codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null); + TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( "Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors"); @@ -2699,10 +2700,9 @@ private static void checkByteVectorValues( while (values.nextDoc() != NO_MORE_DOCS) { // search the first maxNumSearches vectors to exercise the graph if (values.docID() % everyNdoc == 0) { - TopDocs docs = - codecReader - .getVectorReader() - .search(fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE); + KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE); + codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null); + TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( "Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors"); diff --git a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java index 5d184aa4ee15..129302a6a069 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java @@ -25,7 +25,7 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; /** LeafReader implemented by codec APIs. */ @@ -257,29 +257,27 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti } @Override - public final TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + public final void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0) { // Field does not exist or does not index vectors - return null; + return; } - - return getVectorReader().search(field, target, k, acceptDocs, visitedLimit); + getVectorReader().search(field, target, knnCollector, acceptDocs); } @Override - public final TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + public final void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0) { // Field does not exist or does not index vectors - return null; + return; } - - return getVectorReader().search(field, target, k, acceptDocs, visitedLimit); + getVectorReader().search(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java index 2fdd6ac7d054..8cb0034d82c8 100644 --- a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java @@ -18,7 +18,7 @@ package org.apache.lucene.index; import java.io.IOException; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; abstract class DocValuesLeafReader extends LeafReader { @@ -58,14 +58,14 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti } @Override - public TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + public void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } @Override - public TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index 6c1f2c932361..3d9c91b88e83 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -21,7 +21,7 @@ import org.apache.lucene.index.FilterLeafReader.FilterTerms; import org.apache.lucene.index.FilterLeafReader.FilterTermsEnum; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.automaton.CompiledAutomaton; @@ -333,8 +333,9 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + public void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + throws IOException { // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would // match all docs to allow timeout checking. @@ -361,12 +362,13 @@ public int length() { } }; - return in.searchNearestVectors(field, target, k, timeoutCheckingAcceptDocs, visitedLimit); + in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs); } @Override - public TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + throws IOException { // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would // match all docs to allow timeout checking. final Bits updatedAcceptDocs = @@ -392,7 +394,7 @@ public int length() { } }; - return in.searchNearestVectors(field, target, k, timeoutCheckingAcceptDocs, visitedLimit); + in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs); } private void checkAndThrowForSearchVectors() { diff --git a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java index e7e25adec778..a37ab0d508f2 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java @@ -18,7 +18,7 @@ import java.io.IOException; import java.util.Iterator; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.AttributeSource; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; @@ -357,15 +357,15 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { - return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); + public void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override - public TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { - return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java index 767874cc87d7..0c70681d43fb 100644 --- a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java @@ -17,8 +17,11 @@ package org.apache.lucene.index; import java.io.IOException; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.Bits; @@ -240,8 +243,21 @@ public final PostingsEnum postings(Term term) throws IOException { * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores. * @lucene.experimental */ - public abstract TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException; + public final TopDocs searchNearestVectors( + String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + FieldInfo fi = getFieldInfos().fieldInfo(field); + if (fi == null || fi.getVectorDimension() == 0) { + // The field does not exist or does not index vectors + return TopDocsCollector.EMPTY_TOPDOCS; + } + k = Math.min(k, getFloatVectorValues(fi.name).size()); + if (k == 0) { + return TopDocsCollector.EMPTY_TOPDOCS; + } + KnnCollector collector = new TopKnnCollector(k, visitedLimit); + searchNearestVectors(field, target, collector, acceptDocs); + return collector.topDocs(); + } /** * Return the k nearest neighbor documents as determined by comparison of their vector values for @@ -268,8 +284,79 @@ public abstract TopDocs searchNearestVectors( * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores. * @lucene.experimental */ - public abstract TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException; + public final TopDocs searchNearestVectors( + String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + FieldInfo fi = getFieldInfos().fieldInfo(field); + if (fi == null || fi.getVectorDimension() == 0) { + // The field does not exist or does not index vectors + return TopDocsCollector.EMPTY_TOPDOCS; + } + k = Math.min(k, getByteVectorValues(fi.name).size()); + if (k == 0) { + return TopDocsCollector.EMPTY_TOPDOCS; + } + KnnCollector collector = new TopKnnCollector(k, visitedLimit); + searchNearestVectors(field, target, collector, acceptDocs); + return collector.topDocs(); + } + + /** + * Return the k nearest neighbor documents as determined by comparison of their vector values for + * this field, to the given vector, by the field's similarity function. The score of each document + * is derived from the vector similarity in a way that ensures scores are positive and that a + * larger score corresponds to a higher ranking. + * + *

The search is allowed to be approximate, meaning the results are not guaranteed to be the + * true k closest neighbors. For large values of k (for example when k is close to the total + * number of documents), the search may also retrieve fewer than k documents. + * + *

The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in + * order of their similarity to the query vector (decreasing scores). The {@link TotalHits} + * contains the number of documents visited during the search. If the search stopped early because + * it hit {@code visitedLimit}, it is indicated through the relation {@code + * TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}. + * + *

The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link + * FieldInfo}. The return value is never {@code null}. + * + * @param field the vector field to search + * @param target the vector-valued query + * @param knnCollector collector with settings for gathering the vector results. + * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} + * if they are all allowed to match. + * @lucene.experimental + */ + public abstract void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + + /** + * Return the k nearest neighbor documents as determined by comparison of their vector values for + * this field, to the given vector, by the field's similarity function. The score of each document + * is derived from the vector similarity in a way that ensures scores are positive and that a + * larger score corresponds to a higher ranking. + * + *

The search is allowed to be approximate, meaning the results are not guaranteed to be the + * true k closest neighbors. For large values of k (for example when k is close to the total + * number of documents), the search may also retrieve fewer than k documents. + * + *

The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in + * order of their similarity to the query vector (decreasing scores). The {@link TotalHits} + * contains the number of documents visited during the search. If the search stopped early because + * it hit {@code visitedLimit}, it is indicated through the relation {@code + * TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}. + * + *

The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link + * FieldInfo}. The return value is never {@code null}. + * + * @param field the vector field to search + * @param target the vector-valued query + * @param knnCollector collector with settings for gathering the vector results. + * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} + * if they are all allowed to match. + * @lucene.experimental + */ + public abstract void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; /** * Get the {@link FieldInfos} describing all fields in this reader. diff --git a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java index 690ccb697e51..1dd105036773 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java @@ -26,8 +26,8 @@ import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; -import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.Bits; import org.apache.lucene.util.Version; @@ -449,25 +449,25 @@ public ByteVectorValues getByteVectorValues(String fieldName) throws IOException } @Override - public TopDocs searchNearestVectors( - String fieldName, float[] target, int k, Bits acceptDocs, int visitedLimit) + public void searchNearestVectors( + String fieldName, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ensureOpen(); LeafReader reader = fieldToReader.get(fieldName); - return reader == null - ? null - : reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit); + if (reader != null) { + reader.searchNearestVectors(fieldName, target, knnCollector, acceptDocs); + } } @Override - public TopDocs searchNearestVectors( - String fieldName, byte[] target, int k, Bits acceptDocs, int visitedLimit) + public void searchNearestVectors( + String fieldName, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ensureOpen(); LeafReader reader = fieldToReader.get(fieldName); - return reader == null - ? null - : reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit); + if (reader != null) { + reader.searchNearestVectors(fieldName, target, knnCollector, acceptDocs); + } } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java index 3e45deaae9a4..81a3dcd6908d 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java @@ -28,7 +28,7 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; /** @@ -173,15 +173,15 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); + reader.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override - public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); + reader.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index 289461dbabf5..357e0e01d5dc 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -31,9 +31,9 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; -import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IOSupplier; @@ -498,13 +498,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs search( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) { throw new UnsupportedOperationException(); } @Override - public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) { + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java new file mode 100644 index 000000000000..45586e29b053 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +/** + * AbstractKnnCollector is the default implementation for a knn collector used for gathering kNN + * results and providing topDocs from the gathered neighbors + */ +public abstract class AbstractKnnCollector implements KnnCollector { + + private long visitedCount; + private final long visitLimit; + private final int k; + + protected AbstractKnnCollector(int k, long visitLimit) { + this.visitLimit = visitLimit; + this.k = k; + } + + @Override + public final boolean earlyTerminated() { + return visitedCount >= visitLimit; + } + + @Override + public final void incVisitedCount(int count) { + assert count > 0; + this.visitedCount += count; + } + + @Override + public final long visitedCount() { + return visitedCount; + } + + @Override + public final long visitLimit() { + return visitLimit; + } + + @Override + public final int k() { + return k; + } + + @Override + public abstract boolean collect(int docId, float similarity); + + @Override + public abstract float minCompetitiveSimilarity(); + + @Override + public abstract TopDocs topDocs(); +} diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 10345cd7adf4..681ed2b2d9ff 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -28,8 +28,8 @@ import org.apache.lucene.util.Bits; /** - * Uses {@link KnnVectorsReader#search(String, byte[], int, Bits, int)} to perform nearest neighbour - * search. + * Uses {@link KnnVectorsReader#search(String, byte[], KnnCollector, Bits)} to perform nearest + * neighbour search. * *

This query also allows for performing a kNN search subject to a filter. In this case, it first * executes the filter for each leaf, then chooses a strategy dynamically: diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java new file mode 100644 index 000000000000..43bac9fbc309 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +/** + * KnnCollector is a knn collector used for gathering kNN results and providing topDocs from the + * gathered neighbors + * + * @lucene.experimental + */ +public interface KnnCollector { + + /** + * If search visits too many documents, the results collector will terminate early. Usually, this + * is due to some restricted filter on the document set. + * + *

When collection is earlyTerminated, the results are not a correct representation of k + * nearest neighbors. + * + * @return is the current result set marked as incomplete? + */ + boolean earlyTerminated(); + + /** + * @param count increments the visited vector count, must be greater than 0. + */ + void incVisitedCount(int count); + + /** + * @return the current visited vector count + */ + long visitedCount(); + + /** + * @return the visited vector limit + */ + long visitLimit(); + + /** + * @return the expected number of collected results + */ + int k(); + + /** + * Collect the provided docId and include in the result set. + * + * @param docId of the vector to collect + * @param similarity its calculated similarity + * @return true if the vector is collected + */ + boolean collect(int docId, float similarity); + + /** + * This method is utilized during search to ensure only competitive results are explored. + * + *

Consequently, if this results collector wants to collect `k` results, this should return + * {@link Float#NEGATIVE_INFINITY} when not full. + * + *

When full, the minimum score should be returned. + * + * @return the current minimum competitive similarity in the collection + */ + float minCompetitiveSimilarity(); + + /** + * This drains the collected nearest kNN results and returns them in a new {@link TopDocs} + * collection, ordered by score descending. NOTE: This is generally a destructive action and the + * collector should not be used after topDocs() is called. + * + * @return The collected top documents + */ + TopDocs topDocs(); +} diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 3036e7c45162..63b4f8821595 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -29,7 +29,7 @@ import org.apache.lucene.util.VectorUtil; /** - * Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest + * Uses {@link KnnVectorsReader#search(String, float[], KnnCollector, Bits)} to perform nearest * neighbour search. * *

This query also allows for performing a kNN search subject to a filter. In this case, it first diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java index 90c25f4b7e40..5016db61a5fe 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java @@ -20,7 +20,7 @@ import org.apache.lucene.util.Bits; /** - * Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest + * Uses {@link KnnVectorsReader#search(String, float[], KnnCollector, Bits)} to perform nearest * neighbour search. * *

This query also allows for performing a kNN search subject to a filter. In this case, it first diff --git a/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java new file mode 100644 index 000000000000..1f67704da4d1 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import org.apache.lucene.util.hnsw.NeighborQueue; + +/** + * TopKnnCollector is a specific KnnCollector. A minHeap is used to keep track of the currently + * collected vectors allowing for efficient updates as better vectors are collected. + * + * @lucene.experimental + */ +public final class TopKnnCollector extends AbstractKnnCollector { + + private final NeighborQueue queue; + + /** + * @param k the number of neighbors to collect + * @param visitLimit how many vector nodes the results are allowed to visit + */ + public TopKnnCollector(int k, int visitLimit) { + super(k, visitLimit); + this.queue = new NeighborQueue(k, false); + } + + @Override + public boolean collect(int docId, float similarity) { + return queue.insertWithOverflow(docId, similarity); + } + + @Override + public float minCompetitiveSimilarity() { + return queue.size() >= k() ? queue.topScore() : Float.NEGATIVE_INFINITY; + } + + @Override + public TopDocs topDocs() { + assert queue.size() <= k() : "Tried to collect more results than the maximum number allowed"; + ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()]; + for (int i = 1; i <= scoreDocs.length; i++) { + scoreDocs[scoreDocs.length - i] = new ScoreDoc(queue.topNode(), queue.topScore()); + queue.pop(); + } + TotalHits.Relation relation = + earlyTerminated() + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs); + } + + @Override + public String toString() { + return "TopKnnCollector[k=" + k() + ", size=" + queue.size() + "]"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 67f85a787ed9..1c825f6b5df2 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -30,6 +30,8 @@ import java.util.concurrent.TimeUnit; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.InfoStream; @@ -59,7 +61,6 @@ public final class HnswGraphBuilder { public static long randSeed = DEFAULT_RAND_SEED; private final int M; // max number of connections on upper layers - private final int beamWidth; private final double ml; private final NeighborArray scratch; @@ -68,8 +69,9 @@ public final class HnswGraphBuilder { private final RandomAccessVectorValues vectors; private final SplittableRandom random; private final HnswGraphSearcher graphSearcher; - private final NeighborQueue entryCandidates; // for upper levels of graph search - private final NeighborQueue beamCandidates; // for levels of graph where we add the node + private final GraphBuilderKnnCollector entryCandidates; // for upper levels of graph search + private final GraphBuilderKnnCollector + beamCandidates; // for levels of graph where we add the node final OnHeapHnswGraph hnsw; @@ -138,7 +140,6 @@ private HnswGraphBuilder( throw new IllegalArgumentException("beamWidth must be positive"); } this.M = M; - this.beamWidth = beamWidth; // normalization factor for level generation; currently not configurable this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); this.random = new SplittableRandom(seed); @@ -151,8 +152,8 @@ private HnswGraphBuilder( new FixedBitSet(this.vectors.size())); // in scratch we store candidates in reverse order: worse candidates are first scratch = new NeighborArray(Math.max(beamWidth, M + 1), false); - entryCandidates = new NeighborQueue(1, false); - beamCandidates = new NeighborQueue(beamWidth, false); + entryCandidates = new GraphBuilderKnnCollector(1); + beamCandidates = new GraphBuilderKnnCollector(beamWidth); this.initializedNodes = new HashSet<>(); } @@ -283,20 +284,18 @@ public void addGraphNode(int node, T value) throws IOException { } // for levels > nodeLevel search with topk = 1 - NeighborQueue candidates = entryCandidates; + GraphBuilderKnnCollector candidates = entryCandidates; for (int level = curMaxLevel; level > nodeLevel; level--) { candidates.clear(); - graphSearcher.searchLevel( - candidates, value, 1, level, eps, vectors, hnsw, null, Integer.MAX_VALUE); - eps = new int[] {candidates.pop()}; + graphSearcher.searchLevel(candidates, value, level, eps, vectors, hnsw, null); + eps = new int[] {candidates.popNode()}; } // for levels <= nodeLevel search with topk = beamWidth, and add connections candidates = beamCandidates; for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) { candidates.clear(); - graphSearcher.searchLevel( - candidates, value, beamWidth, level, eps, vectors, hnsw, null, Integer.MAX_VALUE); - eps = candidates.nodes(); + graphSearcher.searchLevel(candidates, value, level, eps, vectors, hnsw, null); + eps = candidates.popUntilNearestKNodes(); hnsw.addNode(level, node); addDiverseNeighbors(level, node, candidates); } @@ -319,7 +318,7 @@ private long printGraphBuildStatus(int node, long start, long t) { return now; } - private void addDiverseNeighbors(int level, int node, NeighborQueue candidates) + private void addDiverseNeighbors(int level, int node, GraphBuilderKnnCollector candidates) throws IOException { /* For each of the beamWidth nearest candidates (going from best to worst), select it only if it * is closer to target than it is to any of the already-selected neighbors (ie selected in this method, @@ -360,14 +359,14 @@ private void selectAndLinkDiverse( } } - private void popToScratch(NeighborQueue candidates) { + private void popToScratch(GraphBuilderKnnCollector candidates) { scratch.clear(); int candidateCount = candidates.size(); // extract all the Neighbors from the queue into an array; these will now be // sorted from worst to best for (int i = 0; i < candidateCount; i++) { - float maxSimilarity = candidates.topScore(); - scratch.addInOrder(candidates.pop(), maxSimilarity); + float maxSimilarity = candidates.minimumScore(); + scratch.addInOrder(candidates.popNode(), maxSimilarity); } } @@ -550,4 +549,86 @@ private static int getRandomGraphLevel(double ml, SplittableRandom random) { } while (randDouble == 0.0); return ((int) (-log(randDouble) * ml)); } + + /** + * A restricted, specialized knnCollector that can be used when building a graph. + * + *

Does not support TopDocs + */ + public static final class GraphBuilderKnnCollector implements KnnCollector { + private final NeighborQueue queue; + private final int k; + private long visitedCount; + /** + * @param k the number of neighbors to collect + */ + public GraphBuilderKnnCollector(int k) { + this.queue = new NeighborQueue(k, false); + this.k = k; + } + + public int size() { + return queue.size(); + } + + public int popNode() { + return queue.pop(); + } + + public int[] popUntilNearestKNodes() { + while (size() > k()) { + queue.pop(); + } + return queue.nodes(); + } + + float minimumScore() { + return queue.topScore(); + } + + public void clear() { + this.queue.clear(); + this.visitedCount = 0; + } + + @Override + public boolean earlyTerminated() { + return false; + } + + @Override + public void incVisitedCount(int count) { + this.visitedCount += count; + } + + @Override + public long visitedCount() { + return visitedCount; + } + + @Override + public long visitLimit() { + return Long.MAX_VALUE; + } + + @Override + public int k() { + return k; + } + + @Override + public boolean collect(int docId, float similarity) { + return queue.insertWithOverflow(docId, similarity); + } + + @Override + public float minCompetitiveSimilarity() { + return queue.size() >= k() ? queue.topScore() : Float.NEGATIVE_INFINITY; + } + + @Override + public TopDocs topDocs() { + throw new IllegalArgumentException(); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index ab792ca3bd05..28c7ca2b1638 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -22,6 +22,8 @@ import java.io.IOException; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; @@ -75,9 +77,9 @@ public HnswGraphSearcher( * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or * {@code null} if they are all allowed to match. * @param visitedLimit the maximum number of nodes that the search is allowed to visit - * @return a priority queue holding the closest neighbors found + * @return a set of collected vectors holding the nearest neighbors found */ - public static NeighborQueue search( + public static KnnCollector search( float[] query, int topK, RandomAccessVectorValues vectors, @@ -87,6 +89,32 @@ public static NeighborQueue search( Bits acceptOrds, int visitedLimit) throws IOException { + KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit); + search(query, knnCollector, vectors, vectorEncoding, similarityFunction, graph, acceptOrds); + return knnCollector; + } + + /** + * Searches HNSW graph for the nearest neighbors of a query vector. + * + * @param query search query vector + * @param knnCollector a collector of top knn results to be returned + * @param vectors the vector values + * @param similarityFunction the similarity function to compare vectors + * @param graph the graph values. May represent the entire graph, or a level in a hierarchical + * graph. + * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or + * {@code null} if they are all allowed to match. + */ + public static void search( + float[] query, + KnnCollector knnCollector, + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + HnswGraph graph, + Bits acceptOrds) + throws IOException { if (query.length != vectors.dimension()) { throw new IllegalArgumentException( "vector query dimension: " @@ -98,9 +126,9 @@ public static NeighborQueue search( new HnswGraphSearcher<>( vectorEncoding, similarityFunction, - new NeighborQueue(topK, true), + new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(vectors.size())); - return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + search(query, knnCollector, vectors, graph, graphSearcher, acceptOrds); } /** @@ -108,7 +136,7 @@ public static NeighborQueue search( * {@link #search(float[], int, RandomAccessVectorValues, VectorEncoding, * VectorSimilarityFunction, HnswGraph, Bits, int)} */ - public static NeighborQueue search( + public static KnnCollector search( float[] query, int topK, RandomAccessVectorValues vectors, @@ -118,13 +146,15 @@ public static NeighborQueue search( Bits acceptOrds, int visitedLimit) throws IOException { + KnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit); OnHeapHnswGraphSearcher graphSearcher = new OnHeapHnswGraphSearcher<>( vectorEncoding, similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); - return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + search(query, knnCollector, vectors, graph, graphSearcher, acceptOrds); + return knnCollector; } /** @@ -139,9 +169,9 @@ public static NeighborQueue search( * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or * {@code null} if they are all allowed to match. * @param visitedLimit the maximum number of nodes that the search is allowed to visit - * @return a priority queue holding the closest neighbors found + * @return a set of collected vectors holding the nearest neighbors found */ - public static NeighborQueue search( + public static KnnCollector search( byte[] query, int topK, RandomAccessVectorValues vectors, @@ -151,6 +181,32 @@ public static NeighborQueue search( Bits acceptOrds, int visitedLimit) throws IOException { + KnnCollector collector = new TopKnnCollector(topK, visitedLimit); + search(query, collector, vectors, vectorEncoding, similarityFunction, graph, acceptOrds); + return collector; + } + + /** + * Searches HNSW graph for the nearest neighbors of a query vector. + * + * @param query search query vector + * @param knnCollector a collector of top knn results to be returned + * @param vectors the vector values + * @param similarityFunction the similarity function to compare vectors + * @param graph the graph values. May represent the entire graph, or a level in a hierarchical + * graph. + * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or + * {@code null} if they are all allowed to match. + */ + public static void search( + byte[] query, + KnnCollector knnCollector, + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + HnswGraph graph, + Bits acceptOrds) + throws IOException { if (query.length != vectors.dimension()) { throw new IllegalArgumentException( "vector query dimension: " @@ -162,9 +218,9 @@ public static NeighborQueue search( new HnswGraphSearcher<>( vectorEncoding, similarityFunction, - new NeighborQueue(topK, true), + new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(vectors.size())); - return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + search(query, knnCollector, vectors, graph, graphSearcher, acceptOrds); } /** @@ -172,7 +228,7 @@ public static NeighborQueue search( * {@link #search(byte[], int, RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction, * HnswGraph, Bits, int)} */ - public static NeighborQueue search( + public static KnnCollector search( byte[] query, int topK, RandomAccessVectorValues vectors, @@ -188,44 +244,34 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); - return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + KnnCollector collector = new TopKnnCollector(topK, visitedLimit); + search(query, collector, vectors, graph, graphSearcher, acceptOrds); + return collector; } - private static NeighborQueue search( + private static void search( T query, - int topK, + KnnCollector knnCollector, RandomAccessVectorValues vectors, HnswGraph graph, HnswGraphSearcher graphSearcher, - Bits acceptOrds, - int visitedLimit) + Bits acceptOrds) throws IOException { int initialEp = graph.entryNode(); if (initialEp == -1) { - return new NeighborQueue(1, true); + return; } - int[] epAndVisited = graphSearcher.findBestEntryPoint(query, vectors, graph, visitedLimit); + int[] epAndVisited = + graphSearcher.findBestEntryPoint(query, vectors, graph, knnCollector.visitLimit()); int numVisited = epAndVisited[1]; int ep = epAndVisited[0]; if (ep == -1) { - NeighborQueue results = new NeighborQueue(1, false); - results.setVisitedCount(numVisited); - results.markIncomplete(); - return results; + knnCollector.incVisitedCount(numVisited); + return; } - NeighborQueue results = new NeighborQueue(topK, false); - graphSearcher.searchLevel( - results, - query, - topK, - 0, - new int[] {ep}, - vectors, - graph, - acceptOrds, - visitedLimit - numVisited); - results.setVisitedCount(results.visitedCount() + numVisited); - return results; + KnnCollector results = new OrdinalTranslatedKnnCollector(knnCollector, vectors::ordToDoc); + results.incVisitedCount(numVisited); + graphSearcher.searchLevel(results, query, 0, new int[] {ep}, vectors, graph, acceptOrds); } /** @@ -240,9 +286,9 @@ private static NeighborQueue search( * @param eps the entry points for search at this level expressed as level 0th ordinals * @param vectors vector values * @param graph the graph values - * @return a priority queue holding the closest neighbors found + * @return a set of collected vectors holding the nearest neighbors found */ - public NeighborQueue searchLevel( + public HnswGraphBuilder.GraphBuilderKnnCollector searchLevel( // Note: this is only public because Lucene91HnswGraphBuilder needs it T query, int topK, @@ -251,8 +297,9 @@ public NeighborQueue searchLevel( RandomAccessVectorValues vectors, HnswGraph graph) throws IOException { - NeighborQueue results = new NeighborQueue(topK, false); - searchLevel(results, query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); + HnswGraphBuilder.GraphBuilderKnnCollector results = + new HnswGraphBuilder.GraphBuilderKnnCollector(topK); + searchLevel(results, query, level, eps, vectors, graph, null); return results; } @@ -268,7 +315,7 @@ public NeighborQueue searchLevel( * @throws IOException When accessing the vector fails */ private int[] findBestEntryPoint( - T query, RandomAccessVectorValues vectors, HnswGraph graph, int visitLimit) + T query, RandomAccessVectorValues vectors, HnswGraph graph, long visitLimit) throws IOException { int size = graph.size(); int visitedCount = 1; @@ -313,44 +360,36 @@ private int[] findBestEntryPoint( * last to be popped. */ void searchLevel( - NeighborQueue results, + KnnCollector results, T query, - int topK, int level, final int[] eps, RandomAccessVectorValues vectors, HnswGraph graph, - Bits acceptOrds, - int visitedLimit) + Bits acceptOrds) throws IOException { - assert results.isMinHeap(); int size = graph.size(); prepareScratchState(vectors.size()); - int numVisited = 0; for (int ep : eps) { if (visited.getAndSet(ep) == false) { - if (numVisited >= visitedLimit) { - results.markIncomplete(); + if (results.earlyTerminated()) { break; } float score = compare(query, vectors, ep); - numVisited++; + results.incVisitedCount(1); candidates.add(ep, score); if (acceptOrds == null || acceptOrds.get(ep)) { - results.add(ep, score); + results.collect(ep, score); } } } // A bound that holds the minimum similarity to the query vector that a candidate vector must // have to be considered. - float minAcceptedSimilarity = Float.NEGATIVE_INFINITY; - if (results.size() >= topK) { - minAcceptedSimilarity = results.topScore(); - } - while (candidates.size() > 0 && results.incomplete() == false) { + float minAcceptedSimilarity = results.minCompetitiveSimilarity(); + while (candidates.size() > 0 && results.earlyTerminated() == false) { // get the best candidate (closest or best scoring) float topCandidateSimilarity = candidates.topScore(); if (topCandidateSimilarity < minAcceptedSimilarity) { @@ -366,26 +405,21 @@ void searchLevel( continue; } - if (numVisited >= visitedLimit) { - results.markIncomplete(); + if (results.earlyTerminated()) { break; } float friendSimilarity = compare(query, vectors, friendOrd); - numVisited++; + results.incVisitedCount(1); if (friendSimilarity >= minAcceptedSimilarity) { candidates.add(friendOrd, friendSimilarity); if (acceptOrds == null || acceptOrds.get(friendOrd)) { - if (results.insertWithOverflow(friendOrd, friendSimilarity) && results.size() >= topK) { - minAcceptedSimilarity = results.topScore(); + if (results.collect(friendOrd, friendSimilarity)) { + minAcceptedSimilarity = results.minCompetitiveSimilarity(); } } } } } - while (results.size() > topK) { - results.pop(); - } - results.setVisitedCount(numVisited); } private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/IntToIntFunction.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/IntToIntFunction.java new file mode 100644 index 000000000000..e1ff7751c5ce --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/IntToIntFunction.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +/** Native int to int function */ +public interface IntToIntFunction { + int apply(int v); +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java index 4058c0743f2a..bcc43435872b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborQueue.java @@ -174,10 +174,6 @@ public void markIncomplete() { this.incomplete = true; } - boolean isMinHeap() { - return order == Order.MIN_HEAP; - } - @Override public String toString() { return "Neighbors[" + heap.size() + "]"; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java new file mode 100644 index 000000000000..e529b22feaff --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; + +/** + * Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId + */ +final class OrdinalTranslatedKnnCollector implements KnnCollector { + + private final KnnCollector in; + private final IntToIntFunction vectorOrdinalToDocId; + + OrdinalTranslatedKnnCollector(KnnCollector in, IntToIntFunction vectorOrdinalToDocId) { + this.in = in; + this.vectorOrdinalToDocId = vectorOrdinalToDocId; + } + + @Override + public boolean earlyTerminated() { + return in.earlyTerminated(); + } + + @Override + public void incVisitedCount(int count) { + in.incVisitedCount(count); + } + + @Override + public long visitedCount() { + return in.visitedCount(); + } + + @Override + public long visitLimit() { + return in.visitLimit(); + } + + @Override + public int k() { + return in.k(); + } + + @Override + public boolean collect(int vectorId, float similarity) { + return in.collect(vectorOrdinalToDocId.apply(vectorId), similarity); + } + + @Override + public float minCompetitiveSimilarity() { + return in.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + TopDocs td = in.topDocs(); + return new TopDocs( + new TotalHits( + visitedCount(), + this.earlyTerminated() + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO), + td.scoreDocs); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java index 956749e678ec..d924a8f9cbd0 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java @@ -46,4 +46,14 @@ public interface RandomAccessVectorValues { * {@link RandomAccessVectorValues#vectorValue}. */ RandomAccessVectorValues copy() throws IOException; + + /** + * Translates vector ordinal to the correct document ID. By default, this is an identity function. + * + * @param ord the vector ordinal + * @return the document Id for that vector ordinal + */ + default int ordToDoc(int ord) { + return ord; + } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java index c863ed16b077..adbac83734ee 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java @@ -26,9 +26,9 @@ import java.util.concurrent.TimeUnit; import org.apache.lucene.document.Document; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; @@ -122,16 +122,12 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { - return null; - } + public void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override - public TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) { - return null; - } + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override protected void doClose() {} diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index 24de1a463c73..e7e9c7487d6e 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -152,6 +152,20 @@ public void testFindAll() throws IOException { } } + public void testFindFewer() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0, 0}, 2); + assertMatches(searcher, kvq, 2); + ScoreDoc[] scoreDocs = searcher.search(kvq, 3).scoreDocs; + assertEquals(scoreDocs.length, 2); + assertIdMatches(reader, "id2", scoreDocs[0]); + assertIdMatches(reader, "id0", scoreDocs[1]); + } + } + public void testSearchBoost() throws IOException { try (Directory indexStore = getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTopKnnResults.java b/lucene/core/src/test/org/apache/lucene/search/TestTopKnnResults.java new file mode 100644 index 000000000000..197d06e0ae14 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestTopKnnResults.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestTopKnnResults extends LuceneTestCase { + + public void testCollectAndProvideResults() { + TopKnnCollector results = new TopKnnCollector(5, Integer.MAX_VALUE); + int[] nodes = new int[] {4, 1, 5, 7, 8, 10, 2}; + float[] scores = new float[] {1f, 0.5f, 0.6f, 2f, 2f, 1.2f, 4f}; + for (int i = 0; i < nodes.length; i++) { + results.collect(nodes[i], scores[i]); + } + TopDocs topDocs = results.topDocs(); + int[] sortedNodes = new int[topDocs.scoreDocs.length]; + float[] sortedScores = new float[topDocs.scoreDocs.length]; + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + sortedNodes[i] = topDocs.scoreDocs[i].doc; + sortedScores[i] = topDocs.scoreDocs[i].score; + } + assertArrayEquals(new int[] {2, 7, 8, 10, 4}, sortedNodes); + assertArrayEquals(new float[] {4f, 2f, 2f, 1.2f, 1f}, sortedScores, 0f); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 7bb8fcb7022e..8bcef207e315 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -61,6 +61,7 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.Sort; @@ -312,7 +313,7 @@ public void testAknnDiverse() throws IOException { vectors, getVectorEncoding(), similarityFunction, 10, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors.copy()); // run some searches - final NeighborQueue nn; + final KnnCollector nn; switch (getVectorEncoding()) { case FLOAT32: nn = @@ -341,11 +342,11 @@ public void testAknnDiverse() throws IOException { default: throw new IllegalArgumentException("unexpected vector encoding: " + getVectorEncoding()); } - int[] nodes = nn.nodes(); - assertEquals("Number of found results is not equal to [10].", 10, nodes.length); + TopDocs topDocs = nn.topDocs(); + assertEquals("Number of found results is not equal to [10].", 10, topDocs.scoreDocs.length); int sum = 0; - for (int node : nodes) { - sum += node; + for (ScoreDoc node : topDocs.scoreDocs) { + sum += node.doc; } // We expect to get approximately 100% recall; // the lowest docIds are closest to zero; sum(0,9) = 45 @@ -372,7 +373,7 @@ public void testSearchWithAcceptOrds() throws IOException { OnHeapHnswGraph hnsw = builder.build(vectors.copy()); // the first 10 docs must not be deleted to ensure the expected recall Bits acceptOrds = createRandomAcceptOrds(10, nDoc); - final NeighborQueue nn; + final KnnCollector nn; switch (getVectorEncoding()) { case FLOAT32: nn = @@ -401,12 +402,12 @@ public void testSearchWithAcceptOrds() throws IOException { default: throw new IllegalArgumentException("unexpected vector encoding: " + getVectorEncoding()); } - int[] nodes = nn.nodes(); - assertEquals("Number of found results is not equal to [10].", 10, nodes.length); + TopDocs nodes = nn.topDocs(); + assertEquals("Number of found results is not equal to [10].", 10, nodes.scoreDocs.length); int sum = 0; - for (int node : nodes) { - assertTrue("the results include a deleted document: " + node, acceptOrds.get(node)); - sum += node; + for (ScoreDoc node : nodes.scoreDocs) { + assertTrue("the results include a deleted document: " + node, acceptOrds.get(node.doc)); + sum += node.doc; } // We expect to get approximately 100% recall; // the lowest docIds are closest to zero; sum(0,9) = 45 @@ -430,7 +431,7 @@ public void testSearchWithSelectiveAcceptOrds() throws IOException { // Check the search finds all accepted vectors int numAccepted = acceptOrds.cardinality(); - final NeighborQueue nn; + final KnnCollector nn; switch (getVectorEncoding()) { case FLOAT32: nn = @@ -459,10 +460,10 @@ public void testSearchWithSelectiveAcceptOrds() throws IOException { default: throw new IllegalArgumentException("unexpected vector encoding: " + getVectorEncoding()); } - int[] nodes = nn.nodes(); - assertEquals(numAccepted, nodes.length); - for (int node : nodes) { - assertTrue("the results include a deleted document: " + node, acceptOrds.get(node)); + TopDocs nodes = nn.topDocs(); + assertEquals(numAccepted, nodes.scoreDocs.length); + for (ScoreDoc node : nodes.scoreDocs) { + assertTrue("the results include a deleted document: " + node, acceptOrds.get(node.doc)); } } @@ -710,7 +711,7 @@ public void testVisitedLimit() throws IOException { int topK = 50; int visitedLimit = topK + random().nextInt(5); - final NeighborQueue nn; + final KnnCollector nn; switch (getVectorEncoding()) { case FLOAT32: nn = @@ -739,7 +740,7 @@ public void testVisitedLimit() throws IOException { default: throw new IllegalArgumentException("unexpected vector encoding: " + getVectorEncoding()); } - assertTrue(nn.incomplete()); + assertTrue(nn.earlyTerminated()); // The visited count shouldn't exceed the limit assertTrue(nn.visitedCount() <= visitedLimit); } @@ -939,7 +940,7 @@ public void testRandom() throws IOException { int totalMatches = 0; for (int i = 0; i < 100; i++) { - final NeighborQueue actual; + final KnnCollector actual; T query = randomVector(dim); switch (getVectorEncoding()) { case BYTE: @@ -969,9 +970,8 @@ public void testRandom() throws IOException { default: throw new IllegalArgumentException("unexpected vector encoding: " + getVectorEncoding()); } - while (actual.size() > topK) { - actual.pop(); - } + + TopDocs topDocs = actual.topDocs(); NeighborQueue expected = new NeighborQueue(topK, false); for (int j = 0; j < size; j++) { if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) { @@ -989,8 +989,11 @@ public void testRandom() throws IOException { } } } - assertEquals(topK, actual.size()); - totalMatches += computeOverlap(actual.nodes(), expected.nodes()); + int[] actualTopKDocs = new int[topK]; + for (int j = 0; j < topK; j++) { + actualTopKDocs[j] = topDocs.scoreDocs[j].doc; + } + totalMatches += computeOverlap(actualTopKDocs, expected.nodes()); } double overlap = totalMatches / (double) (100 * topK); System.out.println("overlap=" + overlap + " totalMatches=" + totalMatches); @@ -1004,7 +1007,6 @@ public void testOnHeapHnswGraphSearch() int size = atLeast(100); int dim = atLeast(10); AbstractMockVectorValues vectors = vectorValues(size, dim); - int topK = 5; HnswGraphBuilder builder = HnswGraphBuilder.create( vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong()); @@ -1012,9 +1014,9 @@ public void testOnHeapHnswGraphSearch() Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size); List queries = new ArrayList<>(); - List expects = new ArrayList<>(); + List expects = new ArrayList<>(); for (int i = 0; i < 100; i++) { - NeighborQueue expect = null; + KnnCollector expect = null; T query = randomVector(dim); queries.add(query); switch (getVectorEncoding()) { @@ -1042,21 +1044,17 @@ public void testOnHeapHnswGraphSearch() acceptOrds, Integer.MAX_VALUE); } - ; - while (expect.size() > topK) { - expect.pop(); - } expects.add(expect); } ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("onHeapHnswSearch")); - List> futures = new ArrayList<>(); + List> futures = new ArrayList<>(); for (T query : queries) { futures.add( exec.submit( () -> { - NeighborQueue actual = null; + KnnCollector actual = null; try { switch (getVectorEncoding()) { @@ -1084,25 +1082,29 @@ public void testOnHeapHnswGraphSearch() acceptOrds, Integer.MAX_VALUE); } - ; } catch (IOException ioe) { throw new RuntimeException(ioe); } - while (actual.size() > topK) { - actual.pop(); - } return actual; })); } - List actuals = new ArrayList<>(); - for (Future future : futures) { + List actuals = new ArrayList<>(); + for (Future future : futures) { actuals.add(future.get(10, TimeUnit.SECONDS)); } exec.shutdownNow(); for (int i = 0; i < expects.size(); i++) { - NeighborQueue expect = expects.get(i); - NeighborQueue actual = actuals.get(i); - assertArrayEquals(expect.nodes(), actual.nodes()); + TopDocs expect = expects.get(i).topDocs(); + TopDocs actual = actuals.get(i).topDocs(); + int[] expectedDocs = new int[expect.scoreDocs.length]; + for (int j = 0; j < expect.scoreDocs.length; j++) { + expectedDocs[j] = expect.scoreDocs[j].doc; + } + int[] actualDocs = new int[actual.scoreDocs.length]; + for (int j = 0; j < actual.scoreDocs.length; j++) { + actualDocs[j] = actual.scoreDocs[j].doc; + } + assertArrayEquals(expectedDocs, actualDocs); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java index 5dda5bf0a838..7ecb68295fb0 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java @@ -27,8 +27,11 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.FixedBitSet; import org.junit.Before; @@ -137,7 +140,7 @@ public void testSearchWithSkewedAcceptOrds() throws IOException { for (int i = 500; i < nDoc; i++) { acceptOrds.set(i); } - NeighborQueue nn = + KnnCollector nn = HnswGraphSearcher.search( getTargetVector(), 10, @@ -148,12 +151,12 @@ public void testSearchWithSkewedAcceptOrds() throws IOException { acceptOrds, Integer.MAX_VALUE); - int[] nodes = nn.nodes(); - assertEquals("Number of found results is not equal to [10].", 10, nodes.length); + TopDocs nodes = nn.topDocs(); + assertEquals("Number of found results is not equal to [10].", 10, nodes.scoreDocs.length); int sum = 0; - for (int node : nodes) { - assertTrue("the results include a deleted document: " + node, acceptOrds.get(node)); - sum += node; + for (ScoreDoc node : nodes.scoreDocs) { + assertTrue("the results include a deleted document: " + node, acceptOrds.get(node.doc)); + sum += node.doc; } // We still expect to get reasonable recall. The lowest non-skipped docIds // are closest to the query vector: sum(500,509) = 5045 diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java index 411f84412c6c..b5dcd325b53a 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java @@ -40,7 +40,7 @@ import org.apache.lucene.index.Terms; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.Version; @@ -171,16 +171,12 @@ public ByteVectorValues getByteVectorValues(String fieldName) { } @Override - public TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { - return null; - } + public void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override - public TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) { - return null; - } + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override public void checkIntegrity() throws IOException {} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinByteKnnVectorQuery.java new file mode 100644 index 000000000000..13fb0749b116 --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinByteKnnVectorQuery.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.join; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.HitQueue; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.Bits; + +/** kNN byte vector query that joins matching children vector documents with their parent doc id. */ +public class ToParentBlockJoinByteKnnVectorQuery extends KnnByteVectorQuery { + private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; + + private final BitSetProducer parentsFilter; + private final Query childFilter; + private final int k; + private final byte[] query; + + /** + * Create a ToParentBlockJoinByteVectorQuery. + * + * @param field the query field + * @param query the vector query + * @param childFilter the child filter + * @param k how many parent documents to return given the matching children + * @param parentsFilter Filter identifying the parent documents. + */ + public ToParentBlockJoinByteKnnVectorQuery( + String field, byte[] query, Query childFilter, int k, BitSetProducer parentsFilter) { + super(field, query, k, childFilter); + this.childFilter = childFilter; + this.parentsFilter = parentsFilter; + this.k = k; + this.query = query; + } + + @Override + protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) + throws IOException { + FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); + if (fi == null || fi.getVectorDimension() == 0) { + // The field does not exist or does not index vectors + return NO_RESULTS; + } + if (fi.getVectorEncoding() != VectorEncoding.BYTE) { + return null; + } + BitSet parentBitSet = parentsFilter.getBitSet(context); + ParentBlockJoinByteVectorScorer vectorScorer = + new ParentBlockJoinByteVectorScorer( + context.reader().getByteVectorValues(field), + acceptIterator, + parentBitSet, + query, + fi.getVectorSimilarityFunction()); + HitQueue queue = new HitQueue(k, true); + ScoreDoc topDoc = queue.top(); + int doc; + while ((doc = vectorScorer.nextParent()) != DocIdSetIterator.NO_MORE_DOCS) { + float score = vectorScorer.score(); + if (score > topDoc.score) { + topDoc.score = score; + topDoc.doc = doc; + topDoc = queue.updateTop(); + } + } + + // Remove any remaining sentinel values + while (queue.size() > 0 && queue.top().score < 0) { + queue.pop(); + } + + ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()]; + for (int i = topScoreDocs.length - 1; i >= 0; i--) { + topScoreDocs[i] = queue.pop(); + } + + TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO); + return new TopDocs(totalHits, topScoreDocs); + } + + @Override + protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit) + throws IOException { + BitSet parentBitSet = parentsFilter.getBitSet(context); + KnnCollector collector = new ToParentJoinKnnCollector(k, visitedLimit, parentBitSet); + context.reader().searchNearestVectors(field, query, collector, acceptDocs); + return collector.topDocs(); + } + + @Override + public String toString(String field) { + return getClass().getSimpleName() + ":" + this.field + "[" + query[0] + ",...][" + k + "]"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + ToParentBlockJoinByteKnnVectorQuery that = (ToParentBlockJoinByteKnnVectorQuery) o; + return k == that.k + && Objects.equals(parentsFilter, that.parentsFilter) + && Objects.equals(childFilter, that.childFilter) + && Arrays.equals(query, that.query); + } + + @Override + public int hashCode() { + int result = Objects.hash(super.hashCode(), parentsFilter, childFilter, k); + result = 31 * result + Arrays.hashCode(query); + return result; + } + + private static class ParentBlockJoinByteVectorScorer { + private final byte[] query; + private final ByteVectorValues values; + private final VectorSimilarityFunction similarity; + private final DocIdSetIterator acceptedChildrenIterator; + private final BitSet parentBitSet; + private int currentParent = -1; + private float currentScore = Float.NEGATIVE_INFINITY; + + protected ParentBlockJoinByteVectorScorer( + ByteVectorValues values, + DocIdSetIterator acceptedChildrenIterator, + BitSet parentBitSet, + byte[] query, + VectorSimilarityFunction similarity) { + this.query = query; + this.values = values; + this.similarity = similarity; + this.acceptedChildrenIterator = acceptedChildrenIterator; + this.parentBitSet = parentBitSet; + } + + public int nextParent() throws IOException { + int nextChild = acceptedChildrenIterator.docID(); + if (nextChild == -1) { + nextChild = acceptedChildrenIterator.nextDoc(); + } + if (nextChild == DocIdSetIterator.NO_MORE_DOCS) { + currentParent = DocIdSetIterator.NO_MORE_DOCS; + return currentParent; + } + currentScore = Float.NEGATIVE_INFINITY; + currentParent = parentBitSet.nextSetBit(nextChild); + do { + values.advance(nextChild); + currentScore = Math.max(currentScore, similarity.compare(query, values.vectorValue())); + } while ((nextChild = acceptedChildrenIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS + && nextChild < currentParent); + return currentParent; + } + + public float score() throws IOException { + return currentScore; + } + } +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinFloatKnnVectorQuery.java new file mode 100644 index 000000000000..2327b3d178a7 --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinFloatKnnVectorQuery.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.join; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.HitQueue; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopDocsCollector; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.Bits; + +/** + * kNN float vector query that joins matching children vector documents with their parent doc id. + */ +public class ToParentBlockJoinFloatKnnVectorQuery extends KnnFloatVectorQuery { + private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; + + private final BitSetProducer parentsFilter; + private final Query childFilter; + private final int k; + private final float[] query; + + /** + * Create a ToParentBlockJoinFloatVectorQuery. + * + * @param field the query field + * @param query the vector query + * @param childFilter the child filter + * @param k how many parent documents to return given the matching children + * @param parentsFilter Filter identifying the parent documents. + */ + public ToParentBlockJoinFloatKnnVectorQuery( + String field, float[] query, Query childFilter, int k, BitSetProducer parentsFilter) { + super(field, query, k, childFilter); + this.childFilter = childFilter; + this.parentsFilter = parentsFilter; + this.k = k; + this.query = query; + } + + @Override + protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) + throws IOException { + FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); + if (fi == null || fi.getVectorDimension() == 0) { + // The field does not exist or does not index vectors + return NO_RESULTS; + } + if (fi.getVectorEncoding() != VectorEncoding.FLOAT32) { + return null; + } + BitSet parentBitSet = parentsFilter.getBitSet(context); + ParentBlockJoinFloatVectorScorer vectorScorer = + new ParentBlockJoinFloatVectorScorer( + context.reader().getFloatVectorValues(field), + acceptIterator, + parentBitSet, + query, + fi.getVectorSimilarityFunction()); + HitQueue queue = new HitQueue(k, true); + ScoreDoc topDoc = queue.top(); + int doc; + while ((doc = vectorScorer.nextParent()) != DocIdSetIterator.NO_MORE_DOCS) { + float score = vectorScorer.score(); + if (score > topDoc.score) { + topDoc.score = score; + topDoc.doc = doc; + topDoc = queue.updateTop(); + } + } + + // Remove any remaining sentinel values + while (queue.size() > 0 && queue.top().score < 0) { + queue.pop(); + } + + ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()]; + for (int i = topScoreDocs.length - 1; i >= 0; i--) { + topScoreDocs[i] = queue.pop(); + } + + TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO); + return new TopDocs(totalHits, topScoreDocs); + } + + @Override + protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit) + throws IOException { + BitSet parentBitSet = parentsFilter.getBitSet(context); + KnnCollector collector = new ToParentJoinKnnCollector(k, visitedLimit, parentBitSet); + context.reader().searchNearestVectors(field, query, collector, acceptDocs); + return collector.topDocs(); + } + + @Override + public String toString(String field) { + return getClass().getSimpleName() + ":" + this.field + "[" + query[0] + ",...][" + k + "]"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + ToParentBlockJoinFloatKnnVectorQuery that = (ToParentBlockJoinFloatKnnVectorQuery) o; + return k == that.k + && Objects.equals(parentsFilter, that.parentsFilter) + && Objects.equals(childFilter, that.childFilter) + && Arrays.equals(query, that.query); + } + + @Override + public int hashCode() { + int result = Objects.hash(super.hashCode(), parentsFilter, childFilter, k); + result = 31 * result + Arrays.hashCode(query); + return result; + } + + private static class ParentBlockJoinFloatVectorScorer { + private final float[] query; + private final FloatVectorValues values; + private final VectorSimilarityFunction similarity; + private final DocIdSetIterator acceptedChildrenIterator; + private final BitSet parentBitSet; + private int currentParent = -1; + private float currentScore = Float.NEGATIVE_INFINITY; + + protected ParentBlockJoinFloatVectorScorer( + FloatVectorValues values, + DocIdSetIterator acceptedChildrenIterator, + BitSet parentBitSet, + float[] query, + VectorSimilarityFunction similarity) { + this.query = query; + this.values = values; + this.similarity = similarity; + this.acceptedChildrenIterator = acceptedChildrenIterator; + this.parentBitSet = parentBitSet; + } + + public int nextParent() throws IOException { + int nextChild = acceptedChildrenIterator.docID(); + if (nextChild == -1) { + nextChild = acceptedChildrenIterator.nextDoc(); + } + if (nextChild == DocIdSetIterator.NO_MORE_DOCS) { + currentParent = DocIdSetIterator.NO_MORE_DOCS; + return currentParent; + } + currentScore = Float.NEGATIVE_INFINITY; + currentParent = parentBitSet.nextSetBit(nextChild); + do { + values.advance(nextChild); + currentScore = Math.max(currentScore, similarity.compare(query, values.vectorValue())); + } while ((nextChild = acceptedChildrenIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS + && nextChild < currentParent); + return currentParent; + } + + public float score() throws IOException { + return currentScore; + } + } +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/ToParentJoinKnnCollector.java b/lucene/join/src/java/org/apache/lucene/search/join/ToParentJoinKnnCollector.java new file mode 100644 index 000000000000..669cfabe2db4 --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/ToParentJoinKnnCollector.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.join; + +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.search.AbstractKnnCollector; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.BitSet; + +/** parent joining knn collector, vector docIds are deduplicated according to the parent bit set. */ +class ToParentJoinKnnCollector extends AbstractKnnCollector { + + private final BitSet parentBitSet; + private final NodeIdCachingHeap heap; + + /** + * Create a new object for joining nearest child kNN documents with a parent bitset + * + * @param k The number of joined parent documents to collect + * @param visitLimit how many child vectors can be visited + * @param parentBitSet The leaf parent bitset + */ + public ToParentJoinKnnCollector(int k, int visitLimit, BitSet parentBitSet) { + super(k, visitLimit); + this.parentBitSet = parentBitSet; + this.heap = new NodeIdCachingHeap(k); + } + + /** + * If the heap is not full (size is less than the initialSize provided to the constructor), adds a + * new node-and-score element. If the heap is full, compares the score against the current top + * score, and replaces the top element if newScore is better than (greater than unless the heap is + * reversed), the current top score. + * + *

If docId's parent node has previously been collected and the provided nodeScore is less than + * the stored score it will not be collected. + * + * @param docId the neighbor docId + * @param nodeScore the score of the neighbor, relative to some other node + */ + @Override + public boolean collect(int docId, float nodeScore) { + assert !parentBitSet.get(docId); + int nodeId = parentBitSet.nextSetBit(docId); + return heap.insertWithOverflow(nodeId, nodeScore); + } + + @Override + public float minCompetitiveSimilarity() { + return heap.size >= k() ? heap.topScore() : Float.NEGATIVE_INFINITY; + } + + @Override + public String toString() { + return "ToParentJoinKnnCollector[k=" + k() + ", size=" + heap.size() + "]"; + } + + @Override + public TopDocs topDocs() { + assert heap.size() <= k() : "Tried to collect more results than the maximum number allowed"; + while (heap.size() > k()) { + heap.popToDrain(); + } + ScoreDoc[] scoreDocs = new ScoreDoc[heap.size()]; + for (int i = 1; i <= scoreDocs.length; i++) { + scoreDocs[scoreDocs.length - i] = new ScoreDoc(heap.topNode(), heap.topScore()); + heap.popToDrain(); + } + + TotalHits.Relation relation = + earlyTerminated() + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs); + } + + /** + * This is a minimum binary heap, inspired by {@link org.apache.lucene.util.LongHeap}. But instead + * of encoding and using `long` values. Node ids and scores are kept separate. Additionally, this + * prevents duplicate nodes from being added. + * + *

So, for every node added, we will update its score if the newly provided score is better. + * Every time we update a node's stored score, we ensure the heap's order. + */ + private static class NodeIdCachingHeap { + private final int maxSize; + private int[] heapNodes; + private float[] heapScores; + private int size = 0; + + // Used to keep track of nodeId -> positionInHeap. This way when new scores are added for a + // node, the heap can be + // updated efficiently. + private final Map nodeIdHeapIndex; + private boolean closed = false; + + public NodeIdCachingHeap(int maxSize) { + final int heapSize; + if (maxSize < 1 || maxSize >= ArrayUtil.MAX_ARRAY_LENGTH) { + // Throw exception to prevent confusing OOME: + throw new IllegalArgumentException( + "maxSize must be > 0 and < " + (ArrayUtil.MAX_ARRAY_LENGTH - 1) + "; got: " + maxSize); + } + // NOTE: we add +1 because all access to heap is 1-based not 0-based. heap[0] is unused. + heapSize = maxSize + 1; + this.maxSize = maxSize; + this.nodeIdHeapIndex = + new HashMap<>(maxSize < 2 ? maxSize + 1 : (int) (maxSize / 0.75 + 1.0)); + this.heapNodes = new int[heapSize]; + this.heapScores = new float[heapSize]; + } + + public final int topNode() { + return heapNodes[1]; + } + + public final float topScore() { + return heapScores[1]; + } + + private void pushIn(int nodeId, float score) { + size++; + if (size == heapNodes.length) { + heapNodes = ArrayUtil.grow(heapNodes, (size * 3 + 1) / 2); + heapScores = ArrayUtil.grow(heapScores, (size * 3 + 1) / 2); + } + heapNodes[size] = nodeId; + heapScores[size] = score; + upHeap(size); + } + + private void updateElement(int heapIndex, int nodeId, float score) { + int oldValue = heapNodes[heapIndex]; + assert oldValue == nodeId + : "attempted to update heap element value but with a different node id"; + float oldScore = heapScores[heapIndex]; + heapNodes[heapIndex] = nodeId; + heapScores[heapIndex] = score; + // Since we are a min heap, if the new value is less, we need to make sure to bubble it up + if (score < oldScore) { + upHeap(heapIndex); + } else { + downHeap(heapIndex); + } + } + + /** + * Adds a value to an heap in log(size) time. If the number of values would exceed the heap's + * maxSize, the least value is discarded. + * + *

If `node` already exists in the heap, this will return true if the stored score is updated + * OR the heap is not currently at the maxSize. + * + * @return whether the value was added or updated + */ + public boolean insertWithOverflow(int node, float score) { + if (closed) { + throw new IllegalStateException(); + } + Integer previousNodeIndex = nodeIdHeapIndex.get(node); + if (previousNodeIndex != null) { + if (heapScores[previousNodeIndex] < score) { + updateElement(previousNodeIndex, node, score); + return true; + } + return false; + } + if (size >= maxSize) { + if (score < heapScores[1] || (score == heapScores[1] && node > heapNodes[1])) { + return false; + } + updateTop(node, score); + return true; + } + pushIn(node, score); + return true; + } + + private void popToDrain() { + closed = true; + if (size > 0) { + heapNodes[1] = heapNodes[size]; // move last to first + heapScores[1] = heapScores[size]; // move last to first + size--; + downHeapWithoutCacheUpdate(1); // adjust heap + } else { + throw new IllegalStateException("The heap is empty"); + } + } + + private void updateTop(int nodeId, float score) { + nodeIdHeapIndex.remove(heapNodes[1]); + heapNodes[1] = nodeId; + heapScores[1] = score; + downHeap(1); + } + + /** Returns the number of elements currently stored in the PriorityQueue. */ + public final int size() { + return size; + } + + private boolean lessThan(int nodel, float scorel, int noder, float scorer) { + if (scorel < scorer) { + return true; + } + return scorel == scorer && nodel > noder; + } + + private void upHeap(int origPos) { + int i = origPos; + int bottomNode = heapNodes[i]; + float bottomScore = heapScores[i]; + int j = i >>> 1; + while (j > 0 && lessThan(bottomNode, bottomScore, heapNodes[j], heapScores[j])) { + heapNodes[i] = heapNodes[j]; + heapScores[i] = heapScores[j]; + nodeIdHeapIndex.put(heapNodes[i], i); + i = j; + j = j >>> 1; + } + nodeIdHeapIndex.put(bottomNode, i); + heapNodes[i] = bottomNode; + heapScores[i] = bottomScore; + } + + private int downHeap(int i) { + int node = heapNodes[i]; + float score = heapScores[i]; + int j = i << 1; // find smaller child + int k = j + 1; + if (k <= size && lessThan(heapNodes[k], heapScores[k], heapNodes[j], heapScores[j])) { + j = k; + } + while (j <= size && lessThan(heapNodes[j], heapScores[j], node, score)) { + heapNodes[i] = heapNodes[j]; + heapScores[i] = heapScores[j]; + nodeIdHeapIndex.put(heapNodes[i], i); + i = j; + j = i << 1; + k = j + 1; + if (k <= size && lessThan(heapNodes[k], heapScores[k], heapNodes[j], heapScores[j])) { + j = k; + } + } + nodeIdHeapIndex.put(node, i); + heapNodes[i] = node; // install saved value + heapScores[i] = score; // install saved value + return i; + } + + private int downHeapWithoutCacheUpdate(int i) { + int node = heapNodes[i]; + float score = heapScores[i]; + int j = i << 1; // find smaller child + int k = j + 1; + if (k <= size && lessThan(heapNodes[k], heapScores[k], heapNodes[j], heapScores[j])) { + j = k; + } + while (j <= size && lessThan(heapNodes[j], heapScores[j], node, score)) { + heapNodes[i] = heapNodes[j]; + heapScores[i] = heapScores[j]; + i = j; + j = i << 1; + k = j + 1; + if (k <= size && lessThan(heapNodes[k], heapScores[k], heapNodes[j], heapScores[j])) { + j = k; + } + } + heapNodes[i] = node; // install saved value + heapScores[i] = score; // install saved value + return i; + } + } +} diff --git a/lucene/join/src/test/org/apache/lucene/search/join/ParentBlockJoinKnnVectorQueryTestCase.java b/lucene/join/src/test/org/apache/lucene/search/join/ParentBlockJoinKnnVectorQueryTestCase.java new file mode 100644 index 000000000000..6402fe133e62 --- /dev/null +++ b/lucene/join/src/test/org/apache/lucene/search/join/ParentBlockJoinKnnVectorQueryTestCase.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.join; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.LuceneTestCase; + +abstract class ParentBlockJoinKnnVectorQueryTestCase extends LuceneTestCase { + + static String encodeInts(int[] i) { + return Arrays.toString(i); + } + + static BitSetProducer parentFilter(IndexReader r) throws IOException { + // Create a filter that defines "parent" documents in the index + BitSetProducer parentsFilter = + new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent"))); + CheckJoinIndex.check(r, parentsFilter); + return parentsFilter; + } + + Document makeParent(int[] children) { + Document parent = new Document(); + parent.add(newStringField("docType", "_parent", Field.Store.NO)); + parent.add(newStringField("id", encodeInts(children), Field.Store.YES)); + return parent; + } + + abstract Query getParentJoinKnnQuery( + String fieldName, float[] queryVector, Query childFilter, int k, BitSetProducer parentBitSet); + + public void testEmptyIndex() throws IOException { + try (Directory indexStore = getIndexStore("field"); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + Query kvq = + getParentJoinKnnQuery( + "field", + new float[] {1, 2}, + null, + 2, + new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent")))); + assertMatches(searcher, kvq, 0); + Query q = searcher.rewrite(kvq); + assertTrue(q instanceof MatchNoDocsQuery); + } + } + + public void testFilterWithNoVectorMatches() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + Query filter = new TermQuery(new Term("other", "value")); + BitSetProducer parentFilter = parentFilter(reader); + Query kvq = getParentJoinKnnQuery("field", new float[] {1, 2}, filter, 2, parentFilter); + TopDocs topDocs = searcher.search(kvq, 3); + assertEquals(0, topDocs.totalHits.value); + } + } + + public void testScoringWithMultipleChildren() throws IOException { + try (Directory d = newDirectory()) { + try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { + List toAdd = new ArrayList<>(); + for (int j = 1; j <= 5; j++) { + Document doc = new Document(); + doc.add(getKnnVectorField("field", new float[] {j, j})); + doc.add(newStringField("id", Integer.toString(j), Field.Store.YES)); + toAdd.add(doc); + } + toAdd.add(makeParent(new int[] {1, 2, 3, 4, 5})); + w.addDocuments(toAdd); + + toAdd = new ArrayList<>(); + for (int j = 7; j <= 11; j++) { + Document doc = new Document(); + doc.add(getKnnVectorField("field", new float[] {j, j})); + doc.add(newStringField("id", Integer.toString(j), Field.Store.YES)); + toAdd.add(doc); + } + toAdd.add(makeParent(new int[] {6, 7, 8, 9, 10})); + w.addDocuments(toAdd); + } + try (IndexReader reader = DirectoryReader.open(d)) { + assertEquals(1, reader.leaves().size()); + IndexSearcher searcher = new IndexSearcher(reader); + BitSetProducer parentFilter = parentFilter(searcher.getIndexReader()); + Query query = getParentJoinKnnQuery("field", new float[] {2, 2}, null, 3, parentFilter); + assertScorerResults( + searcher, + query, + new float[] {1f, 1f / 51f}, + new String[] { + encodeInts(new int[] {1, 2, 3, 4, 5}), encodeInts(new int[] {6, 7, 8, 9, 10}) + }); + + query = getParentJoinKnnQuery("field", new float[] {6, 6}, null, 3, parentFilter); + assertScorerResults( + searcher, + query, + new float[] {1f / 3f, 1f / 3f}, + new String[] { + encodeInts(new int[] {1, 2, 3, 4, 5}), encodeInts(new int[] {6, 7, 8, 9, 10}) + }); + query = + getParentJoinKnnQuery( + "field", new float[] {6, 6}, new MatchAllDocsQuery(), 20, parentFilter); + assertScorerResults( + searcher, + query, + new float[] {1f / 3f, 1f / 3f}, + new String[] { + encodeInts(new int[] {1, 2, 3, 4, 5}), encodeInts(new int[] {6, 7, 8, 9, 10}) + }); + + query = + getParentJoinKnnQuery( + "field", new float[] {6, 6}, new MatchAllDocsQuery(), 1, parentFilter); + assertScorerResults( + searcher, + query, + new float[] {1f / 3f}, + new String[] {encodeInts(new int[] {1, 2, 3, 4, 5})}); + } + } + } + + /** Test that when vectors are abnormally distributed among segments, we still find the top K */ + public void testSkewedIndex() throws IOException { + /* We have to choose the numbers carefully here so that some segment has more than the expected + * number of top K documents, but no more than K documents in total (otherwise we might occasionally + * randomly fail to find one). + */ + try (Directory d = newDirectory()) { + try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { + int r = 0; + for (int i = 0; i < 5; i++) { + for (int j = 0; j < 5; j++) { + List toAdd = new ArrayList<>(); + Document doc = new Document(); + doc.add(getKnnVectorField("field", new float[] {r, r})); + doc.add(newStringField("id", Integer.toString(r), Field.Store.YES)); + toAdd.add(doc); + toAdd.add(makeParent(new int[] {r})); + w.addDocuments(toAdd); + ++r; + } + w.flush(); + } + } + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + TopDocs results = + searcher.search( + getParentJoinKnnQuery( + "field", new float[] {0, 0}, null, 8, parentFilter(searcher.getIndexReader())), + 10); + assertEquals(8, results.scoreDocs.length); + assertIdMatches(reader, "[0]", results.scoreDocs[0].doc); + assertIdMatches(reader, "[7]", results.scoreDocs[7].doc); + + // test some results in the middle of the sequence - also tests docid tiebreaking + results = + searcher.search( + getParentJoinKnnQuery( + "field", + new float[] {10, 10}, + null, + 8, + parentFilter(searcher.getIndexReader())), + 10); + assertEquals(8, results.scoreDocs.length); + assertIdMatches(reader, "[10]", results.scoreDocs[0].doc); + assertIdMatches(reader, "[6]", results.scoreDocs[7].doc); + } + } + } + + Directory getIndexStore(String field, float[]... contents) throws IOException { + Directory indexStore = newDirectory(); + RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore); + for (int i = 0; i < contents.length; ++i) { + List toAdd = new ArrayList<>(); + Document doc = new Document(); + doc.add(getKnnVectorField(field, contents[i])); + doc.add(newStringField("id", Integer.toString(i), Field.Store.YES)); + toAdd.add(doc); + toAdd.add(makeParent(new int[] {i})); + writer.addDocuments(toAdd); + } + // Add some documents without a vector + for (int i = 0; i < 5; i++) { + List toAdd = new ArrayList<>(); + Document doc = new Document(); + doc.add(new StringField("other", "value", Field.Store.NO)); + toAdd.add(doc); + toAdd.add(makeParent(new int[0])); + writer.addDocuments(toAdd); + } + writer.close(); + return indexStore; + } + + // @Override + abstract Field getKnnVectorField(String name, float[] vector); + + abstract Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction vectorSimilarityFunction); + + private void assertMatches(IndexSearcher searcher, Query q, int expectedMatches) + throws IOException { + ScoreDoc[] result = searcher.search(q, 1000).scoreDocs; + assertEquals(expectedMatches, result.length); + } + + void assertIdMatches(IndexReader reader, String expectedId, int docId) throws IOException { + String actualId = reader.storedFields().document(docId).get("id"); + assertEquals(expectedId, actualId); + } + + void assertScorerResults(IndexSearcher searcher, Query query, float[] scores, String[] ids) + throws IOException { + IndexReader reader = searcher.getIndexReader(); + Query rewritten = query.rewrite(searcher); + Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1); + Scorer scorer = weight.scorer(searcher.getIndexReader().leaves().get(0)); + // prior to advancing, score is undefined + assertEquals(-1, scorer.docID()); + expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score); + DocIdSetIterator it = scorer.iterator(); + for (int i = 0; i < scores.length; i++) { + int docId = it.nextDoc(); + assertNotEquals(NO_MORE_DOCS, docId); + assertEquals(scores[i], scorer.score(), 0.0001); + assertIdMatches(reader, ids[i], docId); + } + } +} diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java b/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java index ca4246196fce..8ebf7a1e2e97 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java @@ -31,6 +31,7 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.SortedDocValuesField; import org.apache.lucene.document.StoredField; @@ -47,6 +48,7 @@ import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.ReaderUtil; import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.*; import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.similarities.BasicStats; @@ -93,6 +95,19 @@ private Document makeQualification(String qualification, int year) { return job; } + private Document makeVector(String vectorField, float[] value) { + Document vectorDoc = new Document(); + vectorDoc.add(new KnnFloatVectorField(vectorField, value)); + return vectorDoc; + } + + private Document makeParent(String parentId) { + Document parent = new Document(); + parent.add(newStringField("docType", "_parent", Field.Store.NO)); + parent.add(newStringField("parent_id", parentId, Store.YES)); + return parent; + } + public void testEmptyChildFilter() throws Exception { final Directory dir = newDirectory(); final IndexWriterConfig config = new IndexWriterConfig(new MockAnalyzer(random())); @@ -229,6 +244,52 @@ public void testBQShouldJoinedChild() throws Exception { dir.close(); } + public void testSimpleKnn() throws Exception { + + final Directory dir = newDirectory(); + final RandomIndexWriter w = new RandomIndexWriter(random(), dir); + + final List docs = new ArrayList<>(); + + docs.add(makeVector("vector", new float[] {1f, 2f, 3f})); + docs.add(makeVector("vector", new float[] {3f, 3f, 3f})); + docs.add(makeParent("parent1")); + w.addDocuments(docs); + + docs.clear(); + docs.add(makeVector("vector", new float[] {0f, 0f, 1f})); + docs.add(makeVector("vector", new float[] {1f, 1f, 1f})); + docs.add(makeParent("parent2")); + w.addDocuments(docs); + + IndexReader r = w.getReader(); + w.close(); + IndexSearcher s = newSearcher(r, false); + + // Create a filter that defines "parent" documents in the index + BitSetProducer parentsFilter = + new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent"))); + CheckJoinIndex.check(r, parentsFilter); + + ToParentBlockJoinFloatKnnVectorQuery childKnnJoin = + new ToParentBlockJoinFloatKnnVectorQuery( + "vector", new float[] {4f, 4f, 4f}, null, 3, parentsFilter); + + TopDocs topDocs = s.search(childKnnJoin, 5); + assertEquals(2, topDocs.totalHits.value); + Document parentDoc = s.storedFields().document(topDocs.scoreDocs[0].doc); + assertEquals("parent1", parentDoc.get("parent_id")); + assertEquals( + topDocs.scoreDocs[0].score, + VectorSimilarityFunction.EUCLIDEAN.compare( + new float[] {4f, 4f, 4f}, new float[] {3f, 3f, 3f}), + 1e-7); + parentDoc = s.storedFields().document(topDocs.scoreDocs[1].doc); + assertEquals("parent2", parentDoc.get("parent_id")); + r.close(); + dir.close(); + } + public void testSimple() throws Exception { final Directory dir = newDirectory(); diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java new file mode 100644 index 000000000000..6f5de4383a6e --- /dev/null +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.join; + +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.Query; + +public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase { + + @Override + Query getParentJoinKnnQuery( + String fieldName, + float[] queryVector, + Query childFilter, + int k, + BitSetProducer parentBitSet) { + return new ToParentBlockJoinByteKnnVectorQuery( + fieldName, fromFloat(queryVector), childFilter, k, parentBitSet); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnByteVectorField(name, fromFloat(vector)); + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction vectorSimilarityFunction) { + return new KnnByteVectorField(name, fromFloat(vector), vectorSimilarityFunction); + } + + private static byte[] fromFloat(float[] queryVector) { + byte[] query = new byte[queryVector.length]; + for (int i = 0; i < queryVector.length; i++) { + assert queryVector[i] == (byte) queryVector[i]; + query[i] = (byte) queryVector[i]; + } + return query; + } +} diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java new file mode 100644 index 000000000000..8281def61fb6 --- /dev/null +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.join; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; + +public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase { + + @Override + Query getParentJoinKnnQuery( + String fieldName, + float[] queryVector, + Query childFilter, + int k, + BitSetProducer parentBitSet) { + return new ToParentBlockJoinFloatKnnVectorQuery( + fieldName, queryVector, childFilter, k, parentBitSet); + } + + public void testScoreCosine() throws IOException { + try (Directory d = newDirectory()) { + try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { + for (int j = 1; j <= 5; j++) { + List toAdd = new ArrayList<>(); + Document doc = new Document(); + doc.add(getKnnVectorField("field", new float[] {j, j * j}, COSINE)); + toAdd.add(doc); + toAdd.add(makeParent(new int[] {j})); + w.addDocuments(toAdd); + } + } + try (IndexReader reader = DirectoryReader.open(d)) { + assertEquals(1, reader.leaves().size()); + IndexSearcher searcher = new IndexSearcher(reader); + BitSetProducer parentFilter = parentFilter(searcher.getIndexReader()); + ToParentBlockJoinFloatKnnVectorQuery query = + new ToParentBlockJoinFloatKnnVectorQuery( + "field", new float[] {2, 3}, null, 3, parentFilter); + /* score0 = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then + * normalized by (1 + x) /2. + */ + float score0 = + (float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2); + + /* score1 = ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then + * normalized by (1 + x) /2 + */ + float score1 = + (float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2); + + assertScorerResults( + searcher, query, new float[] {score0, score1}, new String[] {"[1]", "[2]"}); + } + } + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnFloatVectorField(name, vector); + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction vectorSimilarityFunction) { + return new KnnFloatVectorField(name, vector, vectorSimilarityFunction); + } +} diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestToParentJoinKnnResults.java b/lucene/join/src/test/org/apache/lucene/search/join/TestToParentJoinKnnResults.java new file mode 100644 index 000000000000..fdac59fc3605 --- /dev/null +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestToParentJoinKnnResults.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.join; + +import java.io.IOException; +import java.util.Arrays; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.BitSet; + +public class TestToParentJoinKnnResults extends LuceneTestCase { + + public void testNeighborsProduct() throws IOException { + // make sure we have the sign correct + BitSet parentBitSet = BitSet.of(new IntArrayDocIdSetIterator(new int[] {1, 3, 5}, 3), 6); + ToParentJoinKnnCollector nn = new ToParentJoinKnnCollector(2, Integer.MAX_VALUE, parentBitSet); + assertTrue(nn.collect(2, 0.5f)); + assertTrue(nn.collect(0, 0.2f)); + assertTrue(nn.collect(4, 1f)); + assertEquals(0.5f, nn.minCompetitiveSimilarity(), 0); + TopDocs topDocs = nn.topDocs(); + assertEquals(topDocs.scoreDocs[0].score, 1f, 0); + assertEquals(topDocs.scoreDocs[1].score, 0.5f, 0); + } + + public void testInsertions() throws IOException { + int[] nodes = new int[] {4, 1, 5, 7, 8, 10, 2}; + float[] scores = new float[] {1f, 0.5f, 0.6f, 2f, 2f, 1.2f, 4f}; + BitSet parentBitSet = BitSet.of(new IntArrayDocIdSetIterator(new int[] {3, 6, 9, 12}, 4), 13); + ToParentJoinKnnCollector results = + new ToParentJoinKnnCollector(7, Integer.MAX_VALUE, parentBitSet); + for (int i = 0; i < nodes.length; i++) { + results.collect(nodes[i], scores[i]); + } + TopDocs topDocs = results.topDocs(); + int[] sortedNodes = new int[topDocs.scoreDocs.length]; + float[] sortedScores = new float[topDocs.scoreDocs.length]; + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + sortedNodes[i] = topDocs.scoreDocs[i].doc; + sortedScores[i] = topDocs.scoreDocs[i].score; + } + assertArrayEquals(new int[] {3, 9, 12, 6}, sortedNodes); + assertArrayEquals(new float[] {4f, 2f, 1.2f, 1f}, sortedScores, 0f); + } + + public void testInsertionWithOverflow() throws IOException { + int[] nodes = new int[] {4, 1, 5, 7, 8, 10, 2, 12, 14}; + float[] scores = new float[] {1f, 0.5f, 0.6f, 2f, 2f, 3f, 4f, 1f, 0.2f}; + BitSet parentBitSet = + BitSet.of(new IntArrayDocIdSetIterator(new int[] {3, 6, 9, 11, 13, 15}, 6), 16); + ToParentJoinKnnCollector results = + new ToParentJoinKnnCollector(5, Integer.MAX_VALUE, parentBitSet); + for (int i = 0; i < nodes.length - 1; i++) { + results.collect(nodes[i], scores[i]); + } + assertFalse(results.collect(nodes[nodes.length - 1], scores[nodes.length - 1])); + int[] sortedNodes = new int[5]; + float[] sortedScores = new float[5]; + TopDocs topDocs = results.topDocs(); + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + sortedNodes[i] = topDocs.scoreDocs[i].doc; + sortedScores[i] = topDocs.scoreDocs[i].score; + } + assertArrayEquals(new int[] {3, 11, 9, 6, 13}, sortedNodes); + assertArrayEquals(new float[] {4f, 3f, 2f, 1f, 1f}, sortedScores, 0f); + } + + static class IntArrayDocIdSetIterator extends DocIdSetIterator { + + private final int[] docs; + private final int length; + private int i = 0; + private int doc = -1; + + IntArrayDocIdSetIterator(int[] docs, int length) { + this.docs = docs; + this.length = length; + } + + @Override + public int docID() { + return doc; + } + + @Override + public int nextDoc() throws IOException { + if (i >= length) { + return NO_MORE_DOCS; + } + return doc = docs[i++]; + } + + @Override + public int advance(int target) throws IOException { + int bound = 1; + // given that we use this for small arrays only, this is very unlikely to overflow + while (i + bound < length && docs[i + bound] < target) { + bound *= 2; + } + i = Arrays.binarySearch(docs, i + bound / 2, Math.min(i + bound + 1, length), target); + if (i < 0) { + i = -1 - i; + } + return doc = docs[i++]; + } + + @Override + public long cost() { + return length; + } + } +} diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 302268031638..ea2dbcb86fe3 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -40,11 +40,11 @@ import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.SimpleCollector; -import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.store.Directory; import org.apache.lucene.util.ArrayUtil; @@ -1637,16 +1637,12 @@ public ByteVectorValues getByteVectorValues(String fieldName) { } @Override - public TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { - return null; - } + public void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override - public TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) { - return null; - } + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override public void checkIntegrity() throws IOException { diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java index b00d1928827b..76f6b8a27c2a 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java @@ -31,7 +31,7 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.Bits; @@ -139,29 +139,23 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldInfo fi = fis.fieldInfo(field); assert fi != null && fi.getVectorDimension() > 0 && fi.getVectorEncoding() == VectorEncoding.FLOAT32; - TopDocs hits = delegate.search(field, target, k, acceptDocs, visitedLimit); - assert hits != null; - assert hits.scoreDocs.length <= k; - return hits; + delegate.search(field, target, knnCollector, acceptDocs); } @Override - public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldInfo fi = fis.fieldInfo(field); assert fi != null && fi.getVectorDimension() > 0 && fi.getVectorEncoding() == VectorEncoding.BYTE; - TopDocs hits = delegate.search(field, target, k, acceptDocs, visitedLimit); - assert hits != null; - assert hits.scoreDocs.length <= k; - return hits; + delegate.search(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java index 085f32d0078b..e36c1c5c7e85 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java @@ -43,7 +43,7 @@ import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.TermVectors; import org.apache.lucene.index.Terms; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; /** @@ -241,15 +241,15 @@ public ByteVectorValues getByteVectorValues(String fieldName) throws IOException } @Override - public TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { - return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); + public void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override - public TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { - return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit); + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java index bbf2adbdc046..538fa7c8c157 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java @@ -28,6 +28,8 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.StoredFieldVisitor; import org.apache.lucene.index.StoredFields; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.util.Bits; /** * Shuffles field numbers around to try to trip bugs where field numbers are assumed to always be @@ -73,6 +75,18 @@ public CacheHelper getReaderCacheHelper() { return in.getReaderCacheHelper(); } + @Override + public void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs); + } + + @Override + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs); + } + static FieldInfos shuffleInfos(FieldInfos infos, Random random) { // first, shuffle the order List shuffled = new ArrayList<>(); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java index 4cf7dcd0a5dd..41636b6b7fa4 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java @@ -45,6 +45,7 @@ import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -53,7 +54,6 @@ import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.SimpleCollector; -import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.Bits; @@ -237,16 +237,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { - return null; - } + public void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override - public TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) { - return null; - } + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override public FieldInfos getFieldInfos() {