Skip to content

Commit

Permalink
Add dynamic type support (float/double) for ParquetDenseVectorDocumen…
Browse files Browse the repository at this point in the history
…tGenerator (#2667)
  • Loading branch information
b8zhong authored Jan 3, 2025
1 parent ea48ec8 commit f5c6929
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.parquet.example.data.Group;
import org.apache.parquet.hadoop.ParquetReader;
import org.apache.parquet.hadoop.example.GroupReadSupport;
import org.apache.parquet.schema.PrimitiveType;

/**
* Collection class for managing Parquet dense vectors
Expand Down Expand Up @@ -83,7 +84,7 @@ public FileSegment<ParquetDenseVectorCollection.Document> createFileSegment(Buff
* Inner class representing a file segment for ParquetDenseVectorCollection.
*/
public static class Segment extends FileSegment<ParquetDenseVectorCollection.Document> {
private List<double[]> vectors; // List to store vectors from the Parquet file
private List<float[]> vectors; // List to store vectors from the Parquet file
private List<String> ids; // List to store document IDs
private ParquetReader<Group> reader;
private boolean readerInitialized;
Expand Down Expand Up @@ -152,19 +153,29 @@ protected synchronized void readNext() throws IOException, NoSuchElementExceptio
throw new NoSuchElementException("End of file reached");
}

// Read each record from the Parquet file
// Extract the docid (String) from the record
String docid = record.getString("docid", 0);
ids.add(docid);

// Extract the vector (double[]) from the record
Group vectorGroup = record.getGroup("vector", 0); // Access the 'vector' field
int vectorSize = vectorGroup.getFieldRepetitionCount(0); // Get the number of elements in the vector
double[] vector = new double[vectorSize];
// Extract the vector from the record as float[] (elements may be stored as FLOAT or DOUBLE)
Group vectorGroup = record.getGroup("vector", 0);// Access the 'vector' field
int vectorSize = vectorGroup.getFieldRepetitionCount(0);// Get the number of elements in the vector
float[] vector = new float[vectorSize];

Group firstElement = vectorGroup.getGroup(0, 0);
PrimitiveType.PrimitiveTypeName primitiveType = firstElement.getType().getFields().get(0).asPrimitiveType().getPrimitiveTypeName();
boolean isDouble = primitiveType.equals(PrimitiveType.PrimitiveTypeName.DOUBLE);
boolean isFloat = primitiveType.equals(PrimitiveType.PrimitiveTypeName.FLOAT);

if (!isDouble && !isFloat) {
throw new IllegalArgumentException(String.format("Vector elements must be either DOUBLE or FLOAT, found: %s", primitiveType));
}

// Single-pass read with conditional cast if needed
for (int i = 0; i < vectorSize; i++) {
Group listGroup = vectorGroup.getGroup(0, i); // Access the 'list' group
vector[i] = listGroup.getDouble("element", 0); // Get the double value from the 'element' field
Group listGroup = vectorGroup.getGroup(0, i);
vector[i] = isDouble ? (float) listGroup.getDouble("element", 0) : listGroup.getFloat("element", 0);
}

vectors.add(vector);

// Create a new Document object with the retrieved data
Expand All @@ -177,7 +188,7 @@ protected synchronized void readNext() throws IOException, NoSuchElementExceptio
*/
public static class Document implements SourceDocument {
private final String id;
private final double[] vector;
private final float[] vector;
private final String raw;

/**
Expand All @@ -187,7 +198,7 @@ public static class Document implements SourceDocument {
* @param vector the vector data.
* @param raw the raw data.
*/
public Document(String id, double[] vector, String raw) {
public Document(String id, float[] vector, String raw) {
this.id = id;
this.vector = vector;
this.raw = raw;
Expand Down
21 changes: 21 additions & 0 deletions src/test/java/io/anserini/index/IndexFlatDenseVectorsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,25 @@ public void testQuantizedInt8() throws Exception {
assertNotNull(results);
assertEquals(100, results.get("documents"));
}

/**
 * Runs {@code IndexFlatDenseVectors} over the sample Parquet collection stored under
 * {@code sample_docs/parquet/msmarco-passage-bge-base-en-v1.5.parquet-float} and checks
 * that the resulting index reports 10 documents.
 *
 * @throws Exception if indexing or reading the index fails
 */
@Test
public void testParquetFloat() throws Exception {
  // Timestamp suffix keeps the output directory distinct between runs.
  final String outputDir = "target/lucene-test-index.flat." + System.currentTimeMillis();

  IndexFlatDenseVectors.main(new String[] {
      "-collection", "ParquetDenseVectorCollection",
      "-input", "src/test/resources/sample_docs/parquet/msmarco-passage-bge-base-en-v1.5.parquet-float",
      "-index", outputDir,
      "-generator", "ParquetDenseVectorDocumentGenerator",
      "-threads", "1"
  });

  // The index must be openable after indexing completes.
  final IndexReader indexReader = IndexReaderUtils.getReader(outputDir);
  assertNotNull(indexReader);

  // Stats are looked up for the vector field; the sample is expected to yield 10 documents.
  final Map<String, Object> stats = IndexReaderUtils.getIndexStats(indexReader, Constants.VECTOR);
  assertNotNull(stats);
  assertEquals(10, stats.get("documents"));
}
}
22 changes: 22 additions & 0 deletions src/test/java/io/anserini/index/IndexHnswDenseVectorsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,28 @@ public void testParquet() throws Exception {
assertEquals(10, results.get("documents"));
}

/**
 * Runs {@code IndexHnswDenseVectors} (M=16, efC=100) over the sample Parquet collection stored
 * under {@code sample_docs/parquet/msmarco-passage-bge-base-en-v1.5.parquet-float} and checks
 * that the resulting index reports 10 documents.
 *
 * @throws Exception if indexing or reading the index fails
 */
@Test
public void testParquetFloat() throws Exception {
  // Use the "hnsw" prefix (as testQuantizedInt8 in this class does) so the output directory
  // is named after the index type and cannot clash with the flat-index test's directory.
  String indexPath = "target/lucene-test-index.hnsw." + System.currentTimeMillis();
  String[] indexArgs = new String[] {
      "-collection", "ParquetDenseVectorCollection",
      "-input", "src/test/resources/sample_docs/parquet/msmarco-passage-bge-base-en-v1.5.parquet-float",
      "-index", indexPath,
      "-generator", "ParquetDenseVectorDocumentGenerator",
      "-threads", "1",
      "-M", "16", "-efC", "100"
  };

  IndexHnswDenseVectors.main(indexArgs);

  // The index must be openable after indexing completes.
  IndexReader reader = IndexReaderUtils.getReader(indexPath);
  assertNotNull(reader);

  // Stats are looked up for the vector field; the sample is expected to yield 10 documents.
  Map<String, Object> results = IndexReaderUtils.getIndexStats(reader, Constants.VECTOR);
  assertNotNull(results);
  assertEquals(10, results.get("documents"));
}

@Test
public void testQuantizedInt8() throws Exception {
String indexPath = "target/lucene-test-index.hnsw." + System.currentTimeMillis();
Expand Down
Binary file not shown.

0 comments on commit f5c6929

Please sign in to comment.