Add ParentJoin KNN support #12434

Merged
merged 49 commits into from
Aug 7, 2023
Changes from 44 commits
Commits
17491f3
Add join support for knn
benwtrent Jun 29, 2023
ab950d0
Adding de-duplicating neighborqueue
benwtrent Jun 29, 2023
7e3b95f
adding tests
benwtrent Jun 30, 2023
05ac40e
moving things around
benwtrent Jun 30, 2023
506c7b7
add queue results tests
benwtrent Jul 3, 2023
d57a0b8
more changes
benwtrent Jul 3, 2023
5b03cd3
iterating on api design
benwtrent Jul 6, 2023
3622492
Merge branch 'main' into feature/add-join-support-knn
benwtrent Jul 6, 2023
1481f74
adding new leaf function
benwtrent Jul 7, 2023
f7aa19f
adding to parent block vector query
benwtrent Jul 7, 2023
1992840
updating exact match for parent join vector query
benwtrent Jul 7, 2023
e431bd0
fixing ordinal lookup
benwtrent Jul 7, 2023
1d33c4c
fixing api
benwtrent Jul 10, 2023
147488c
Merge remote-tracking branch 'upstream/main' into feature/add-join-su…
benwtrent Jul 10, 2023
3f7ce5b
javadoc
benwtrent Jul 10, 2023
898d400
fixing tests and formatting
benwtrent Jul 11, 2023
83ff209
fixing tests
benwtrent Jul 11, 2023
b5efd9c
cleaning up interface
benwtrent Jul 12, 2023
f1230df
formatting and reverting unnecessary changes
benwtrent Jul 18, 2023
43fec9d
fixing bad changes
benwtrent Jul 18, 2023
72b57c8
removing gradle update
benwtrent Jul 18, 2023
e1109fd
changing variable declaration
benwtrent Jul 18, 2023
6b04bca
formatting
benwtrent Jul 19, 2023
28c0c88
adding more tests
benwtrent Jul 19, 2023
fb826de
further adjusting API, need to fix tests
benwtrent Jul 19, 2023
1bd1001
adjusting api further
benwtrent Jul 19, 2023
5c96027
fixing tests
benwtrent Jul 19, 2023
6681d88
removing extranious put
benwtrent Jul 20, 2023
c73cd8f
moving numVisited into the results provider and collector
benwtrent Jul 20, 2023
c8d2a32
improving parent join query
benwtrent Jul 20, 2023
c0a6a71
adding parent join for byte vectors
benwtrent Jul 24, 2023
f609e5e
reducing knnResults definition requirements and eagerly translating v…
benwtrent Jul 24, 2023
cbb291e
remove provider interface
benwtrent Jul 24, 2023
3d7f930
moving things around, simplifying leaf reader
benwtrent Jul 25, 2023
9d8a2dd
updating param
benwtrent Jul 25, 2023
a255474
removing isFull and renaming minSimilarity
benwtrent Jul 26, 2023
2fee99e
making visit limit a long
benwtrent Jul 26, 2023
7726200
refactoring
benwtrent Jul 26, 2023
f4962ec
refactoring
benwtrent Jul 26, 2023
3809225
refactoring
benwtrent Jul 26, 2023
690d0ce
formatting
benwtrent Jul 26, 2023
475d440
refactoring
benwtrent Jul 26, 2023
bcb8029
making experimental and adding changes
benwtrent Jul 27, 2023
24cd478
formatting
benwtrent Jul 27, 2023
54cbe0a
Merge remote-tracking branch 'upstream/main' into feature/add-join-su…
benwtrent Jul 27, 2023
8abebcb
addressing pr comments fixing conflict
benwtrent Jul 27, 2023
c5a790c
Merge remote-tracking branch 'upstream/main' into feature/add-join-su…
benwtrent Jul 31, 2023
9cf92e3
cleaning up
benwtrent Jul 31, 2023
a2693ba
Merge remote-tracking branch 'upstream/main' into feature/add-join-su…
benwtrent Aug 7, 2023
4 changes: 4 additions & 0 deletions lucene/CHANGES.txt
@@ -130,6 +130,10 @@ New Features
 * GITHUB#12383: Introduced LeafCollector#finish, a hook that runs after
   collection has finished running on a leaf. (Adrien Grand)
 
+* GITHUB#12434: Add `KnnCollector` to `LeafReader` and `KnnVectorReader` so that custom collection of vector
+  search results can be provided. The first custom collector provides `ToParentBlockJoin[Float|Byte]KnnVectorQuery`
+  joining child vector documents with their parent documents. (Ben Trent)
+
 Improvements
 ---------------------
 * GITHUB#12374: Add CachingLeafSlicesSupplier to compute the LeafSlices for concurrent segment search (Sorabh Hamirwasia)
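
To illustrate the idea behind the `KnnCollector` entry above, here is a minimal, self-contained sketch of a top-k collector. It is not Lucene's implementation: `TopKCollectorSketch`, `Hit`, `topHits()` and `earlyTerminated()` are hypothetical names, and only the method names `k()`, `visitLimit()`, `incVisitedCount(...)` and `collect(...)` are borrowed from calls visible in this diff. Their real signatures and semantics may differ.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

/** Hypothetical stand-in for a KNN collector; NOT Lucene's KnnCollector. */
final class TopKCollectorSketch {

  /** A collected (doc, score) pair. */
  static final class Hit {
    final int doc;
    final float score;

    Hit(int doc, float score) {
      this.doc = doc;
      this.score = score;
    }
  }

  private final int k;
  private final long visitLimit;
  private long visitedCount;

  // Min-heap on score: the head is the weakest hit currently kept.
  private final PriorityQueue<Hit> heap =
      new PriorityQueue<>((a, b) -> Float.compare(a.score, b.score));

  TopKCollectorSketch(int k, long visitLimit) {
    this.k = k;
    this.visitLimit = visitLimit;
  }

  int k() {
    return k;
  }

  long visitLimit() {
    return visitLimit;
  }

  void incVisitedCount(long count) {
    visitedCount += count;
  }

  /** Assumption: the search should stop once the visit budget is used up. */
  boolean earlyTerminated() {
    return visitedCount >= visitLimit;
  }

  /** Offers a candidate; returns true if it was kept among the top k. */
  boolean collect(int doc, float score) {
    if (heap.size() < k) {
      heap.add(new Hit(doc, score));
      return true;
    }
    if (score > heap.peek().score) {
      heap.poll(); // evict the current weakest hit
      heap.add(new Hit(doc, score));
      return true;
    }
    return false;
  }

  /** Returns the kept hits, best score first. */
  List<Hit> topHits() {
    List<Hit> hits = new ArrayList<>(heap);
    hits.sort((a, b) -> Float.compare(b.score, a.score));
    return hits;
  }
}
```

In this sketch a graph search would call `collect` for each candidate it scores and `incVisitedCount` as it visits nodes. A collector with a different retention policy (for example, one that keeps only the best child per parent) could then be swapped in without changing the search loop.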

@@ -25,10 +25,11 @@
 import java.util.List;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.hnsw.HnswGraphBuilder;
 import org.apache.lucene.util.hnsw.HnswGraphSearcher;
-import org.apache.lucene.util.hnsw.NeighborQueue;
 import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
 
 /**
@@ -73,7 +74,7 @@ public List<TermAndBoost> getSynonyms(
     LinkedList<TermAndBoost> result = new LinkedList<>();
     float[] query = word2VecModel.vectorValue(term);
     if (query != null) {
-      NeighborQueue synonyms =
+      KnnCollector synonyms =
           HnswGraphSearcher.search(
               query,
               // The query vector is in the model. When looking for the top-k
@@ -85,16 +86,16 @@ public List<TermAndBoost> getSynonyms(
               hnswGraph,
               null,
               Integer.MAX_VALUE);
+      TopDocs topDocs = synonyms.topDocs();
 
-      int size = synonyms.size();
-      for (int i = 0; i < size; i++) {
-        float similarity = synonyms.topScore();
-        int id = synonyms.pop();
+      for (int i = 0; i < topDocs.scoreDocs.length; i++) {
+        float similarity = topDocs.scoreDocs[i].score;
+        int id = topDocs.scoreDocs[i].doc;
 
         BytesRef synonym = word2VecModel.termValue(id);
         // We remove the original query term
         if (!synonym.equals(term) && similarity >= minAcceptedSimilarity) {
-          result.addFirst(new TermAndBoost(synonym, similarity));
+          result.addLast(new TermAndBoost(synonym, similarity));
         }
       }
     }
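
The switch from `addFirst` to `addLast` in the hunk above appears to preserve the final ordering rather than change it. The sketch below assumes the old queue popped its lowest-scoring entry first (as the `minSimilarity` naming elsewhere in this PR suggests) and that `topDocs.scoreDocs` is ordered best-first. `SynonymOrderSketch` and its methods are hypothetical and use plain `java.util` collections, not Lucene classes.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;

/** Hypothetical illustration of the two iteration styles; not Lucene code. */
final class SynonymOrderSketch {

  /** Old style: pop a min-heap (weakest first) and prepend each element. */
  static List<Float> drainMinHeapWithAddFirst(List<Float> scores) {
    PriorityQueue<Float> minHeap = new PriorityQueue<>(scores);
    LinkedList<Float> result = new LinkedList<>();
    while (!minHeap.isEmpty()) {
      // Each later (stronger) element ends up in front of the earlier ones.
      result.addFirst(minHeap.poll());
    }
    return result;
  }

  /** New style: walk an already best-first sequence and append each element. */
  static List<Float> walkBestFirstWithAddLast(List<Float> scores) {
    List<Float> bestFirst = new ArrayList<>(scores);
    bestFirst.sort(Comparator.reverseOrder());
    LinkedList<Float> result = new LinkedList<>();
    for (float score : bestFirst) {
      result.addLast(score);
    }
    return result;
  }
}
```

Under those assumptions both methods return the scores in descending order, which is why flipping the insertion end is needed when the iteration order flips.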

@@ -34,9 +34,7 @@
 import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.VectorSimilarityFunction;
-import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.search.KnnCollector;
 import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.DataInput;
 import org.apache.lucene.store.IndexInput;
@@ -237,48 +235,39 @@ public ByteVectorValues getByteVectorValues(String field) {
   }
 
   @Override
-  public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
+  public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
       throws IOException {
     FieldEntry fieldEntry = fields.get(field);
 
     if (fieldEntry.size() == 0) {
-      return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+      return;
     }
 
-    // bound k by total number of vectors to prevent oversizing data structures
-    k = Math.min(k, fieldEntry.size());
-
     OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
     // use a seed that is fixed for the index so we get reproducible results for the same query
     final SplittableRandom random = new SplittableRandom(checksumSeed);
     NeighborQueue results =
         Lucene90OnHeapHnswGraph.search(
             target,
-            k,
-            k,
+            knnCollector.k(),
+            knnCollector.k(),
             vectorValues,
             fieldEntry.similarityFunction,
             getGraphValues(fieldEntry),
             getAcceptOrds(acceptDocs, fieldEntry),
-            visitedLimit,
+            knnCollector.visitLimit(),
             random);
-    int i = 0;
-    ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
+    knnCollector.incVisitedCount(results.visitedCount());
     while (results.size() > 0) {
       int node = results.topNode();
       float minSimilarity = results.topScore();
       results.pop();
-      scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], minSimilarity);
+      knnCollector.collect(node, minSimilarity);
     }
-    TotalHits.Relation relation =
-        results.incomplete()
-            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
-            : TotalHits.Relation.EQUAL_TO;
-    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
   }
 
   @Override
-  public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
+  public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
       throws IOException {
     throw new UnsupportedOperationException();
   }

@@ -78,7 +78,7 @@ public static NeighborQueue search(
       VectorSimilarityFunction similarityFunction,
       HnswGraph graphValues,
       Bits acceptOrds,
-      int visitedLimit,
+      long visitedLimit,
       SplittableRandom random)
       throws IOException {
     int size = graphValues.size();

@@ -35,9 +35,7 @@
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
-import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.search.KnnCollector;
 import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.DataInput;
 import org.apache.lucene.store.IndexInput;
@@ -46,7 +44,6 @@
 import org.apache.lucene.util.RamUsageEstimator;
 import org.apache.lucene.util.hnsw.HnswGraph;
 import org.apache.lucene.util.hnsw.HnswGraphSearcher;
-import org.apache.lucene.util.hnsw.NeighborQueue;
 import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
 
 /**
@@ -229,47 +226,28 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
   }
 
   @Override
-  public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
+  public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
       throws IOException {
     FieldEntry fieldEntry = fields.get(field);
 
     if (fieldEntry.size() == 0) {
-      return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+      return;
     }
 
-    // bound k by total number of vectors to prevent oversizing data structures
-    k = Math.min(k, fieldEntry.size());
     OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
 
-    NeighborQueue results =
-        HnswGraphSearcher.search(
-            target,
-            k,
-            vectorValues,
-            VectorEncoding.FLOAT32,
-            fieldEntry.similarityFunction,
-            getGraph(fieldEntry),
-            getAcceptOrds(acceptDocs, fieldEntry),
-            visitedLimit);
-
-    int i = 0;
-    ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
-    while (results.size() > 0) {
-      int node = results.topNode();
-      float minSimilarity = results.topScore();
-      results.pop();
-      scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), minSimilarity);
-    }
-
-    TotalHits.Relation relation =
-        results.incomplete()
-            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
-            : TotalHits.Relation.EQUAL_TO;
-    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
+    HnswGraphSearcher.search(
+        target,
+        knnCollector,
+        vectorValues,
+        VectorEncoding.FLOAT32,
+        fieldEntry.similarityFunction,
+        getGraph(fieldEntry),
+        getAcceptOrds(acceptDocs, fieldEntry));
   }
 
   @Override
-  public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
+  public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
       throws IOException {
     throw new UnsupportedOperationException();
   }

@@ -34,9 +34,7 @@
 import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
-import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.search.KnnCollector;
 import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.DataInput;
 import org.apache.lucene.store.IndexInput;
@@ -45,7 +43,6 @@
 import org.apache.lucene.util.RamUsageEstimator;
 import org.apache.lucene.util.hnsw.HnswGraph;
 import org.apache.lucene.util.hnsw.HnswGraphSearcher;
-import org.apache.lucene.util.hnsw.NeighborQueue;
 import org.apache.lucene.util.packed.DirectMonotonicReader;
 
 /**
@@ -225,47 +222,28 @@ public ByteVectorValues getByteVectorValues(String field) {
   }
 
   @Override
-  public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
+  public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
       throws IOException {
     FieldEntry fieldEntry = fields.get(field);
 
     if (fieldEntry.size() == 0) {
-      return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+      return;
     }
 
-    // bound k by total number of vectors to prevent oversizing data structures
-    k = Math.min(k, fieldEntry.size());
     OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData);
 
-    NeighborQueue results =
-        HnswGraphSearcher.search(
-            target,
-            k,
-            vectorValues,
-            VectorEncoding.FLOAT32,
-            fieldEntry.similarityFunction,
-            getGraph(fieldEntry),
-            vectorValues.getAcceptOrds(acceptDocs),
-            visitedLimit);
-
-    int i = 0;
-    ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
-    while (results.size() > 0) {
-      int node = results.topNode();
-      float minSimilarity = results.topScore();
-      results.pop();
-      scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), minSimilarity);
-    }
-
-    TotalHits.Relation relation =
-        results.incomplete()
-            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
-            : TotalHits.Relation.EQUAL_TO;
-    return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
+    HnswGraphSearcher.search(
+        target,
+        knnCollector,
+        vectorValues,
+        VectorEncoding.FLOAT32,
+        fieldEntry.similarityFunction,
+        getGraph(fieldEntry),
+        vectorValues.getAcceptOrds(acceptDocs));
   }
 
   @Override
-  public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
+  public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
       throws IOException {
     throw new UnsupportedOperationException();
   }

@@ -61,8 +61,6 @@ public float[] vectorValue(int targetOrd) throws IOException {
     return value;
   }
 
-  public abstract int ordToDoc(int ord);
-
   static OffHeapFloatVectorValues load(
       Lucene92HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
     if (fieldEntry.docsWithFieldOffset == -2) {