Skip to content

Commit

Permalink
Add ParentJoin KNN support (#12434)
Browse files Browse the repository at this point in the history
A `join` within Lucene is built by adding child-docs and parent-docs in order. Since our vector field already supports sparse indexing, it should be able to support parent join indexing.

However, when searching for the closest `k`, the results are still the k nearest child vectors, with no way to join back to the parent.

This commit adds this ability through some significant changes:
 - New leaf reader function that allows a collector for knn results
 - The knn results can then utilize bit-sets to join back to the parent id

This type of support is critical for nearest passage retrieval over larger documents. Generally, you want the top-k documents and knowledge of the nearest passages over each top-k document. Lucene's join functionality is a nice fit for this.

This does not replace the need for multi-valued vectors, which are important for other ranking methods (e.g. ColBERT token embeddings). However, it could be used in the case when metadata about the passage embedding must be stored (e.g. the related passage).
  • Loading branch information
benwtrent committed Aug 7, 2023
1 parent 4d2f8a3 commit 1f25c68
Show file tree
Hide file tree
Showing 58 changed files with 2,273 additions and 578 deletions.
4 changes: 4 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ New Features
* LUCENE-8183, GITHUB#9231: Added the ability to get noSubMatches and noOverlappingMatches in
HyphenationCompoundWordFilter (Martin Demberger, original from Rupert Westenthaler)

* GITHUB#12434: Add `KnnCollector` to `LeafReader` and `KnnVectorReader` so that custom collection of vector
search results can be provided. The first custom collector provides `ToParentBlockJoin[Float|Byte]KnnVectorQuery`
joining child vector documents with their parent documents. (Ben Trent)

Improvements
---------------------
* GITHUB#12374: Add CachingLeafSlicesSupplier to compute the LeafSlices for concurrent segment search (Sorabh Hamirwasia)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
import java.util.List;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;

/**
Expand Down Expand Up @@ -73,7 +74,7 @@ public List<TermAndBoost> getSynonyms(
LinkedList<TermAndBoost> result = new LinkedList<>();
float[] query = word2VecModel.vectorValue(term);
if (query != null) {
NeighborQueue synonyms =
KnnCollector synonyms =
HnswGraphSearcher.search(
query,
// The query vector is in the model. When looking for the top-k
Expand All @@ -85,16 +86,16 @@ public List<TermAndBoost> getSynonyms(
hnswGraph,
null,
Integer.MAX_VALUE);
TopDocs topDocs = synonyms.topDocs();

int size = synonyms.size();
for (int i = 0; i < size; i++) {
float similarity = synonyms.topScore();
int id = synonyms.pop();
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
float similarity = topDocs.scoreDocs[i].score;
int id = topDocs.scoreDocs[i].doc;

BytesRef synonym = word2VecModel.termValue(id);
// We remove the original query term
if (!synonym.equals(term) && similarity >= minAcceptedSimilarity) {
result.addFirst(new TermAndBoost(synonym, similarity));
result.addLast(new TermAndBoost(synonym, similarity));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
Expand Down Expand Up @@ -237,48 +235,39 @@ public ByteVectorValues getByteVectorValues(String field) {
}

@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

if (fieldEntry.size() == 0) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
return;
}

// bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size());

OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
// use a seed that is fixed for the index so we get reproducible results for the same query
final SplittableRandom random = new SplittableRandom(checksumSeed);
NeighborQueue results =
Lucene90OnHeapHnswGraph.search(
target,
k,
k,
knnCollector.k(),
knnCollector.k(),
vectorValues,
fieldEntry.similarityFunction,
getGraphValues(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry),
visitedLimit,
knnCollector.visitLimit(),
random);
int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
knnCollector.incVisitedCount(results.visitedCount());
while (results.size() > 0) {
int node = results.topNode();
float minSimilarity = results.topScore();
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], minSimilarity);
knnCollector.collect(node, minSimilarity);
}
TotalHits.Relation relation =
results.incomplete()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}

@Override
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public static NeighborQueue search(
VectorSimilarityFunction similarityFunction,
HnswGraph graphValues,
Bits acceptOrds,
int visitedLimit,
long visitedLimit,
SplittableRandom random)
throws IOException {
int size = graphValues.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
Expand All @@ -46,7 +44,6 @@
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
Expand Down Expand Up @@ -229,47 +226,28 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
}

@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

if (fieldEntry.size() == 0) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
return;
}

// bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size());
OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);

NeighborQueue results =
HnswGraphSearcher.search(
target,
k,
vectorValues,
VectorEncoding.FLOAT32,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry),
visitedLimit);

int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
float minSimilarity = results.topScore();
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc(node), minSimilarity);
}

TotalHits.Relation relation =
results.incomplete()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
HnswGraphSearcher.search(
target,
knnCollector,
vectorValues,
VectorEncoding.FLOAT32,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
getAcceptOrds(acceptDocs, fieldEntry));
}

@Override
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
Expand All @@ -45,7 +43,6 @@
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.packed.DirectMonotonicReader;

/**
Expand Down Expand Up @@ -225,47 +222,28 @@ public ByteVectorValues getByteVectorValues(String field) {
}

@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

if (fieldEntry.size() == 0) {
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
return;
}

// bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size());
OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData);

NeighborQueue results =
HnswGraphSearcher.search(
target,
k,
vectorValues,
VectorEncoding.FLOAT32,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs),
visitedLimit);

int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
while (results.size() > 0) {
int node = results.topNode();
float minSimilarity = results.topScore();
results.pop();
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(vectorValues.ordToDoc(node), minSimilarity);
}

TotalHits.Relation relation =
results.incomplete()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
HnswGraphSearcher.search(
target,
knnCollector,
vectorValues,
VectorEncoding.FLOAT32,
fieldEntry.similarityFunction,
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
}

@Override
public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int visitedLimit)
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ public float[] vectorValue(int targetOrd) throws IOException {
return value;
}

public abstract int ordToDoc(int ord);

static OffHeapFloatVectorValues load(
Lucene92HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
if (fieldEntry.docsWithFieldOffset == -2) {
Expand Down Expand Up @@ -118,11 +116,6 @@ public RandomAccessVectorValues<float[]> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
}

@Override
public int ordToDoc(int ord) {
return ord;
}

@Override
Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
Expand Down
Loading

0 comments on commit 1f25c68

Please sign in to comment.