Skip to content

Commit

Permalink
Revert "tentative improvements to distinguish OnHeap / OffHeap search…
Browse files Browse the repository at this point in the history
…ers"

This reverts commit 7b91173.
  • Loading branch information
zhaih committed May 9, 2023
1 parent 139cde9 commit 5520eea
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 204 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TermAndBoost;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.OnHeapHnswGraphSearcher;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;

/**
Expand All @@ -43,7 +43,7 @@ public class Word2VecSynonymProvider {
VectorSimilarityFunction.DOT_PRODUCT;
private static final VectorEncoding VECTOR_ENCODING = VectorEncoding.FLOAT32;
private final Word2VecModel word2VecModel;
private final OnHeapHnswGraph hnswGraph;
private final HnswGraph hnswGraph;

/**
* Word2VecSynonymProvider constructor
Expand Down Expand Up @@ -75,7 +75,7 @@ public List<TermAndBoost> getSynonyms(
float[] query = word2VecModel.vectorValue(term);
if (query != null) {
NeighborQueue synonyms =
OnHeapHnswGraphSearcher.search(
HnswGraphSearcher.search(
query,
// The query vector is in the model. When looking for the top-k
// it's always the nearest neighbour of itself so, we look for the top-k+1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.OffHeapHnswGraphSearcher;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

Expand Down Expand Up @@ -241,7 +242,7 @@ public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int
OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);

NeighborQueue results =
OffHeapHnswGraphSearcher.search(
HnswGraphSearcher.search(
target,
k,
vectorValues,
Expand Down Expand Up @@ -301,7 +302,7 @@ public int length() {
};
}

private OffHeapHnswGraph getGraph(FieldEntry entry) throws IOException {
private HnswGraph getGraph(FieldEntry entry) throws IOException {
IndexInput bytesSlice =
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
return new OffHeapHnswGraph(entry, bytesSlice);
Expand Down Expand Up @@ -493,7 +494,7 @@ public float[] vectorValue(int targetOrd) throws IOException {
}

/** Read the nearest-neighbors graph from the index input */
private static final class OffHeapHnswGraph extends org.apache.lucene.util.hnsw.OffHeapHnswGraph {
private static final class OffHeapHnswGraph extends HnswGraph {

final IndexInput dataIn;
final int[][] nodesByLevel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.OffHeapHnswGraphSearcher;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.packed.DirectMonotonicReader;

Expand Down Expand Up @@ -237,7 +238,7 @@ public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int
OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData);

NeighborQueue results =
OffHeapHnswGraphSearcher.search(
HnswGraphSearcher.search(
target,
k,
vectorValues,
Expand Down Expand Up @@ -269,7 +270,7 @@ public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int v
throw new UnsupportedOperationException();
}

private OffHeapHnswGraph getGraph(FieldEntry entry) throws IOException {
private HnswGraph getGraph(FieldEntry entry) throws IOException {
IndexInput bytesSlice =
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
return new OffHeapHnswGraph(entry, bytesSlice);
Expand Down Expand Up @@ -385,7 +386,7 @@ int size() {
}

/** Read the nearest-neighbors graph from the index input */
private static final class OffHeapHnswGraph extends org.apache.lucene.util.hnsw.OffHeapHnswGraph {
private static final class OffHeapHnswGraph extends HnswGraph {

final IndexInput dataIn;
final int[][] nodesByLevel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.OffHeapHnswGraphSearcher;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.packed.DirectMonotonicReader;

Expand Down Expand Up @@ -273,7 +274,7 @@ public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int
OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData);

NeighborQueue results =
OffHeapHnswGraphSearcher.search(
HnswGraphSearcher.search(
target,
k,
vectorValues,
Expand Down Expand Up @@ -313,7 +314,7 @@ public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int v
OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);

NeighborQueue results =
OffHeapHnswGraphSearcher.search(
HnswGraphSearcher.search(
target,
k,
vectorValues,
Expand All @@ -339,7 +340,7 @@ public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int v
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
}

private OffHeapHnswGraph getGraph(FieldEntry entry) throws IOException {
private HnswGraph getGraph(FieldEntry entry) throws IOException {
IndexInput bytesSlice =
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
return new OffHeapHnswGraph(entry, bytesSlice);
Expand Down Expand Up @@ -461,7 +462,7 @@ int size() {
}

/** Read the nearest-neighbors graph from the index input */
private static final class OffHeapHnswGraph extends org.apache.lucene.util.hnsw.OffHeapHnswGraph {
private static final class OffHeapHnswGraph extends HnswGraph {

final IndexInput dataIn;
final int[][] nodesByLevel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.OnHeapHnswGraphSearcher;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

Expand All @@ -56,7 +56,7 @@ public final class Lucene91HnswGraphBuilder {
private final RandomAccessVectorValues<float[]> vectorValues;
private final SplittableRandom random;
private final Lucene91BoundsChecker bound;
private final OnHeapHnswGraphSearcher<float[]> graphSearcher;
private final HnswGraphSearcher<float[]> graphSearcher;

final Lucene91OnHeapHnswGraph hnsw;

Expand Down Expand Up @@ -102,7 +102,7 @@ public Lucene91HnswGraphBuilder(
int levelOfFirstNode = getRandomGraphLevel(ml, random);
this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode);
this.graphSearcher =
new OnHeapHnswGraphSearcher<>(
new HnswGraphSearcher<>(
VectorEncoding.FLOAT32,
similarityFunction,
new NeighborQueue(beamWidth, true),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.OffHeapHnswGraphSearcher;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.packed.DirectMonotonicReader;

Expand Down Expand Up @@ -282,7 +282,7 @@ public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int
OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData);

NeighborQueue results =
OffHeapHnswGraphSearcher.search(
HnswGraphSearcher.search(
target,
k,
vectorValues,
Expand Down Expand Up @@ -325,7 +325,7 @@ public TopDocs search(String field, byte[] target, int k, Bits acceptDocs, int v
OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);

NeighborQueue results =
OffHeapHnswGraphSearcher.search(
HnswGraphSearcher.search(
target,
k,
vectorValues,
Expand Down Expand Up @@ -365,7 +365,7 @@ public HnswGraph getGraph(String field) throws IOException {
}
}

private OffHeapHnswGraph getGraph(FieldEntry entry) throws IOException {
private HnswGraph getGraph(FieldEntry entry) throws IOException {
// Builds a new OffHeapHnswGraph from the field entry and the vectorIndex input on each call.
return new OffHeapHnswGraph(entry, vectorIndex);
}

Expand Down Expand Up @@ -489,7 +489,7 @@ public long ramBytesUsed() {
}

/** Read the nearest-neighbors graph from the index input */
private static final class OffHeapHnswGraph extends org.apache.lucene.util.hnsw.OffHeapHnswGraph {
private static final class OffHeapHnswGraph extends HnswGraph {

final IndexInput dataIn;
final int[][] nodesByLevel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public final class HnswGraphBuilder<T> {
private final VectorEncoding vectorEncoding;
private final RandomAccessVectorValues<T> vectors;
private final SplittableRandom random;
private final OnHeapHnswGraphSearcher<T> graphSearcher;
private final HnswGraphSearcher<T> graphSearcher;

final OnHeapHnswGraph hnsw;

Expand Down Expand Up @@ -142,7 +142,7 @@ private HnswGraphBuilder(
this.random = new SplittableRandom(seed);
this.hnsw = new OnHeapHnswGraph(M);
this.graphSearcher =
new OnHeapHnswGraphSearcher<>(
new HnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(beamWidth, true),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,129 @@
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.SparseFixedBitSet;

/**
* Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the
* search algorithm, see {@link HnswGraph}.
*
* @param <T> the type of query vector
*/
public abstract class HnswGraphSearcher<T> {
VectorSimilarityFunction similarityFunction;
VectorEncoding vectorEncoding;
public class HnswGraphSearcher<T> {
private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;

/**
* Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive
* to allocate, so they're cleared and reused across calls.
*/
NeighborQueue candidates;
private final NeighborQueue candidates;

BitSet visited;
private BitSet visited;

/**
 * Builds a graph searcher around caller-supplied scratch structures.
 *
 * @param vectorEncoding the vector encoding stored for use by this searcher
 * @param similarityFunction the similarity function to compare vectors
 * @param candidates max heap that will track the candidate nodes to explore
 * @param visited bit set that will track nodes that have already been visited
 */
public HnswGraphSearcher(
    VectorEncoding vectorEncoding,
    VectorSimilarityFunction similarityFunction,
    NeighborQueue candidates,
    BitSet visited) {
  // Scratch state first, comparison configuration second; the assignments are independent.
  this.visited = visited;
  this.candidates = candidates;
  this.similarityFunction = similarityFunction;
  this.vectorEncoding = vectorEncoding;
}

/**
 * Finds the nearest neighbors of a {@code float[]} query vector in an HNSW graph.
 *
 * <p>Note: if you want to search {@link OnHeapHnswGraph} in a thread-safe manner, please
 * consider using {@link OnHeapHnswGraphSearcher}
 *
 * @param query search query vector
 * @param topK the number of nodes to be returned
 * @param vectors the vector values
 * @param vectorEncoding the vector encoding handed to the searcher created for this call
 * @param similarityFunction the similarity function to compare vectors
 * @param graph the graph values. May represent the entire graph, or a level in a hierarchical
 *     graph.
 * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
 *     {@code null} if they are all allowed to match.
 * @param visitedLimit the maximum number of nodes that the search is allowed to visit
 * @return a priority queue holding the closest neighbors found
 * @throws IllegalArgumentException if {@code query.length} differs from {@code
 *     vectors.dimension()}
 * @throws IOException propagated from the underlying graph search
 */
public static NeighborQueue search(
    float[] query,
    int topK,
    RandomAccessVectorValues<float[]> vectors,
    VectorEncoding vectorEncoding,
    VectorSimilarityFunction similarityFunction,
    HnswGraph graph,
    Bits acceptOrds,
    int visitedLimit)
    throws IOException {
  if (query.length == vectors.dimension()) {
    // Fresh scratch structures per call: a candidate heap sized by topK and a visited set
    // sized by the number of vectors.
    NeighborQueue candidateHeap = new NeighborQueue(topK, true);
    BitSet visitedNodes = new SparseFixedBitSet(vectors.size());
    HnswGraphSearcher<float[]> searcher =
        new HnswGraphSearcher<>(vectorEncoding, similarityFunction, candidateHeap, visitedNodes);
    return search(query, topK, vectors, graph, searcher, acceptOrds, visitedLimit);
  }
  throw new IllegalArgumentException(
      "vector query dimension: "
          + query.length
          + " differs from field dimension: "
          + vectors.dimension());
}

/**
 * Finds the nearest neighbors of a {@code byte[]} query vector in an HNSW graph.
 *
 * <p>Note: if you want to search {@link OnHeapHnswGraph} in a thread-safe manner, please
 * consider using {@link OnHeapHnswGraphSearcher}
 *
 * @param query search query vector
 * @param topK the number of nodes to be returned
 * @param vectors the vector values
 * @param vectorEncoding the vector encoding handed to the searcher created for this call
 * @param similarityFunction the similarity function to compare vectors
 * @param graph the graph values. May represent the entire graph, or a level in a hierarchical
 *     graph.
 * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or
 *     {@code null} if they are all allowed to match.
 * @param visitedLimit the maximum number of nodes that the search is allowed to visit
 * @return a priority queue holding the closest neighbors found
 * @throws IllegalArgumentException if {@code query.length} differs from {@code
 *     vectors.dimension()}
 * @throws IOException propagated from the underlying graph search
 */
public static NeighborQueue search(
    byte[] query,
    int topK,
    RandomAccessVectorValues<byte[]> vectors,
    VectorEncoding vectorEncoding,
    VectorSimilarityFunction similarityFunction,
    HnswGraph graph,
    Bits acceptOrds,
    int visitedLimit)
    throws IOException {
  if (query.length == vectors.dimension()) {
    // Fresh scratch structures per call: a candidate heap sized by topK and a visited set
    // sized by the number of vectors.
    NeighborQueue candidateHeap = new NeighborQueue(topK, true);
    BitSet visitedNodes = new SparseFixedBitSet(vectors.size());
    HnswGraphSearcher<byte[]> searcher =
        new HnswGraphSearcher<>(vectorEncoding, similarityFunction, candidateHeap, visitedNodes);
    return search(query, topK, vectors, graph, searcher, acceptOrds, visitedLimit);
  }
  throw new IllegalArgumentException(
      "vector query dimension: "
          + query.length
          + " differs from field dimension: "
          + vectors.dimension());
}

static <T> NeighborQueue search(
T query,
Expand Down Expand Up @@ -201,7 +306,9 @@ private void prepareScratchState(int capacity) {
*
* @throws IOException when seeking the graph
*/
abstract void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException;
void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException {
// Delegates directly to HnswGraph#seek with the same level and target node.
graph.seek(level, targetNode);
}

/**
* Get the next neighbor from the graph, you must call {@link #graphSeek(HnswGraph, int, int)}
Expand All @@ -211,5 +318,7 @@ private void prepareScratchState(int capacity) {
* @return see {@link HnswGraph#nextNeighbor()}
* @throws IOException when advance neighbors
*/
abstract int graphNextNeighbor(HnswGraph graph) throws IOException;
int graphNextNeighbor(HnswGraph graph) throws IOException {
// Delegates directly to HnswGraph#nextNeighbor and returns its result unchanged.
return graph.nextNeighbor();
}
}

This file was deleted.

Loading

0 comments on commit 5520eea

Please sign in to comment.