Spark 3.5: Make ColumnVectorWithFilter generic and refactor ColumnarBatchReader
aokolnychyi committed Jan 23, 2025
1 parent be6e9da commit 85419ce
Showing 3 changed files with 88 additions and 52 deletions.
org/apache/iceberg/spark/data/vectorized/ColumnVectorWithFilter.java
@@ -18,78 +18,121 @@
  */
 package org.apache.iceberg.spark.data.vectorized;
 
-import org.apache.iceberg.arrow.vectorized.VectorHolder;
 import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.vectorized.ColumnVector;
 import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
 import org.apache.spark.unsafe.types.UTF8String;
 
-public class ColumnVectorWithFilter extends IcebergArrowColumnVector {
+/**
+ * A column vector implementation that applies row-level filtering.
+ *
+ * <p>This class wraps an existing column vector and uses a row ID mapping array to remap row
+ * indices during data access. Each method that retrieves data for a specific row translates the
+ * provided row index using the mapping array, effectively filtering the original data to only
+ * expose the live subset of rows. This approach allows efficient row-level filtering without
+ * modifying the underlying data.
+ */
+public class ColumnVectorWithFilter extends ColumnVector {
+  private final ColumnVector delegate;
   private final int[] rowIdMapping;
 
-  public ColumnVectorWithFilter(VectorHolder holder, int[] rowIdMapping) {
-    super(holder);
+  public ColumnVectorWithFilter(ColumnVector delegate, int[] rowIdMapping) {
+    super(delegate.dataType());
+    this.delegate = delegate;
     this.rowIdMapping = rowIdMapping;
   }
 
   @Override
+  public void close() {
+    delegate.close();
+  }
+
+  @Override
+  public void closeIfFreeable() {
+    delegate.closeIfFreeable();
+  }
+
+  @Override
+  public boolean hasNull() {
+    return delegate.hasNull();
+  }
+
+  @Override
+  public int numNulls() {
+    // computing the actual number of nulls with rowIdMapping is expensive
+    // it is OK to overestimate and return the number of nulls in the original vector
+    return delegate.numNulls();
+  }
+
+  @Override
   public boolean isNullAt(int rowId) {
-    return nullabilityHolder().isNullAt(rowIdMapping[rowId]) == 1;
+    return delegate.isNullAt(rowIdMapping[rowId]);
   }
 
   @Override
   public boolean getBoolean(int rowId) {
-    return accessor().getBoolean(rowIdMapping[rowId]);
+    return delegate.getBoolean(rowIdMapping[rowId]);
+  }
+
+  @Override
+  public byte getByte(int rowId) {
+    return delegate.getByte(rowIdMapping[rowId]);
+  }
+
+  @Override
+  public short getShort(int rowId) {
+    return delegate.getShort(rowIdMapping[rowId]);
   }
 
   @Override
   public int getInt(int rowId) {
-    return accessor().getInt(rowIdMapping[rowId]);
+    return delegate.getInt(rowIdMapping[rowId]);
   }
 
   @Override
   public long getLong(int rowId) {
-    return accessor().getLong(rowIdMapping[rowId]);
+    return delegate.getLong(rowIdMapping[rowId]);
   }
 
   @Override
   public float getFloat(int rowId) {
-    return accessor().getFloat(rowIdMapping[rowId]);
+    return delegate.getFloat(rowIdMapping[rowId]);
   }
 
   @Override
   public double getDouble(int rowId) {
-    return accessor().getDouble(rowIdMapping[rowId]);
+    return delegate.getDouble(rowIdMapping[rowId]);
   }
 
   @Override
   public ColumnarArray getArray(int rowId) {
     if (isNullAt(rowId)) {
       return null;
     }
-    return accessor().getArray(rowIdMapping[rowId]);
+    return delegate.getArray(rowIdMapping[rowId]);
+  }
+
+  @Override
+  public ColumnarMap getMap(int rowId) {
+    return delegate.getMap(rowIdMapping[rowId]);
   }
 
   @Override
   public Decimal getDecimal(int rowId, int precision, int scale) {
     if (isNullAt(rowId)) {
       return null;
     }
-    return accessor().getDecimal(rowIdMapping[rowId], precision, scale);
+    return delegate.getDecimal(rowIdMapping[rowId], precision, scale);
   }
 
   @Override
   public UTF8String getUTF8String(int rowId) {
     if (isNullAt(rowId)) {
       return null;
     }
-    return accessor().getUTF8String(rowIdMapping[rowId]);
+    return delegate.getUTF8String(rowIdMapping[rowId]);
   }
 
   @Override
   public byte[] getBinary(int rowId) {
     if (isNullAt(rowId)) {
       return null;
     }
-    return accessor().getBinary(rowIdMapping[rowId]);
+    return delegate.getBinary(rowIdMapping[rowId]);
+  }
+
+  @Override
+  public ColumnVector getChild(int ordinal) {
+    ColumnVector child = delegate.getChild(ordinal);
+    return new ColumnVectorWithFilter(child, rowIdMapping);
   }
 }
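
To make the remapping concrete, here is a small standalone example. It is not part of the commit: it assumes Spark's OnHeapColumnVector as the delegate and the ColumnVectorWithFilter class introduced above. Logical row IDs seen by the caller are translated to physical row IDs through rowIdMapping, so deleted rows simply disappear from the batch without copying any data:

import org.apache.iceberg.spark.data.vectorized.ColumnVectorWithFilter;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.vectorized.ColumnVector;

public class RowIdMappingDemo {
  public static void main(String[] args) {
    // Four physical rows holding 0, 10, 20, 30.
    OnHeapColumnVector data = new OnHeapColumnVector(4, DataTypes.IntegerType);
    for (int i = 0; i < 4; i++) {
      data.putInt(i, i * 10);
    }

    // Rows 0 and 2 are deleted; the mapping lists the surviving physical indices.
    int[] rowIdMapping = {1, 3};
    ColumnVector filtered = new ColumnVectorWithFilter(data, rowIdMapping);

    // Reads are remapped: logical row 0 -> physical row 1, logical row 1 -> physical row 3.
    System.out.println(filtered.getInt(0)); // 10
    System.out.println(filtered.getInt(1)); // 30

    filtered.close(); // closes the delegate
  }
}

Because getChild wraps nested children with the same mapping, the filtering applies consistently to struct fields as well.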
org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java
@@ -94,43 +94,42 @@ private class ColumnBatchLoader {
   }
 
   ColumnarBatch loadDataToColumnBatch() {
-    ColumnVector[] arrowColumnVectors = readDataToColumnVectors();
+    ColumnVector[] vectors = readDataToColumnVectors();
     int numLiveRows = batchSize;
 
     if (hasIsDeletedColumn) {
-      boolean[] isDeleted =
-          ColumnarBatchUtil.buildIsDeleted(
-              arrowColumnVectors, deletes, rowStartPosInBatch, batchSize);
-      for (int i = 0; i < arrowColumnVectors.length; i++) {
-        ColumnVector vector = arrowColumnVectors[i];
+      boolean[] isDeleted = buildIsDeleted(vectors);
+      for (ColumnVector vector : vectors) {
         if (vector instanceof DeletedColumnVector) {
           ((DeletedColumnVector) vector).setValue(isDeleted);
         }
       }
     } else {
-      Pair<int[], Integer> pair =
-          ColumnarBatchUtil.buildRowIdMapping(
-              arrowColumnVectors, deletes, rowStartPosInBatch, batchSize);
+      Pair<int[], Integer> pair = buildRowIdMapping(vectors);
       if (pair != null) {
        int[] rowIdMapping = pair.first();
        numLiveRows = pair.second();
-        for (int i = 0; i < arrowColumnVectors.length; i++) {
-          ColumnVector vector = arrowColumnVectors[i];
-          if (vector instanceof IcebergArrowColumnVector) {
-            arrowColumnVectors[i] =
-                new ColumnVectorWithFilter(
-                    ((IcebergArrowColumnVector) vector).vector(), rowIdMapping);
-          }
+        for (int i = 0; i < vectors.length; i++) {
+          vectors[i] = new ColumnVectorWithFilter(vectors[i], rowIdMapping);
         }
       }
     }
 
     if (deletes != null && deletes.hasEqDeletes()) {
-      arrowColumnVectors = ColumnarBatchUtil.removeExtraColumns(deletes, arrowColumnVectors);
+      vectors = ColumnarBatchUtil.removeExtraColumns(deletes, vectors);
     }
 
-    ColumnarBatch newColumnarBatch = new ColumnarBatch(arrowColumnVectors);
-    newColumnarBatch.setNumRows(numLiveRows);
-    return newColumnarBatch;
+    ColumnarBatch batch = new ColumnarBatch(vectors);
+    batch.setNumRows(numLiveRows);
+    return batch;
   }
 
+  private boolean[] buildIsDeleted(ColumnVector[] vectors) {
+    return ColumnarBatchUtil.buildIsDeleted(vectors, deletes, rowStartPosInBatch, batchSize);
+  }
+
+  private Pair<int[], Integer> buildRowIdMapping(ColumnVector[] vectors) {
+    return ColumnarBatchUtil.buildRowIdMapping(vectors, deletes, rowStartPosInBatch, batchSize);
+  }
+
   ColumnVector[] readDataToColumnVectors() {
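
For context, buildRowIdMapping produces the mapping consumed above. The sketch below shows the general shape of that computation under simplified assumptions — deleted positions are modeled as a plain java.util.Set rather than the delete filter consulted by the real ColumnarBatchUtil — and, like the real helper as used in the diff, it returns null when no rows in the batch are deleted so the caller can skip wrapping:

import java.util.Set;
import org.apache.iceberg.util.Pair;

class RowIdMappingSketch {
  // Returns (mapping, live row count), or null when every row in the batch survives.
  static Pair<int[], Integer> buildRowIdMapping(
      Set<Long> deletedPositions, long rowStartPosInBatch, int batchSize) {
    int[] rowIdMapping = new int[batchSize];
    int liveRows = 0;
    for (int i = 0; i < batchSize; i++) {
      // Keep the batch-local index of every row whose absolute position survives.
      if (!deletedPositions.contains(rowStartPosInBatch + i)) {
        rowIdMapping[liveRows++] = i;
      }
    }
    return liveRows == batchSize ? null : Pair.of(rowIdMapping, liveRows);
  }

  public static void main(String[] args) {
    // Batch of 4 rows starting at absolute position 100; positions 100 and 102 are deleted.
    Pair<int[], Integer> pair = buildRowIdMapping(Set.of(100L, 102L), 100L, 4);
    System.out.println(pair.second()); // 2 live rows; mapping begins [1, 3, ...]
  }
}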
org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java
@@ -39,17 +39,11 @@ public class IcebergArrowColumnVector extends ColumnVector {
 
   private final ArrowVectorAccessor<Decimal, UTF8String, ColumnarArray, ArrowColumnVector> accessor;
   private final NullabilityHolder nullabilityHolder;
-  private final VectorHolder holder;
 
   public IcebergArrowColumnVector(VectorHolder holder) {
     super(SparkSchemaUtil.convert(holder.icebergType()));
     this.nullabilityHolder = holder.nullabilityHolder();
     this.accessor = ArrowVectorAccessors.getVectorAccessor(holder);
-    this.holder = holder;
   }
 
-  public VectorHolder vector() {
-    return holder;
-  }
-
   protected ArrowVectorAccessor<Decimal, UTF8String, ColumnarArray, ArrowColumnVector> accessor() {
