feat(java): support statistics row num for lance scan (#3304)
Support row-count statistics for the lance scan; with these statistics, Spark will choose a broadcast join for a small table.
For now, the byte size of a lance dataset is inferred from the row count, which is not very precise. We should perhaps store the file size in the metadata, as discussed in #3221.
SaintBacchus authored Jan 3, 2025
1 parent 8fe7147 commit 8a23d50
Showing 7 changed files with 141 additions and 8 deletions.
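
Background for the broadcast behavior mentioned in the commit message: Spark's optimizer compares a scan's estimated size against `spark.sql.autoBroadcastJoinThreshold` (10 MB by default) and broadcasts the smaller side of an equi-join when the estimate fits under it. Below is a minimal sketch of tuning that knob when experimenting with this feature; the configuration key is standard Spark, while the session setup is illustrative only:

```java
import org.apache.spark.sql.SparkSession;

public class BroadcastThresholdDemo {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().master("local[*]").appName("broadcast-threshold-demo").getOrCreate();

    // Default is 10 MB: a scan whose reported size falls below this is broadcast.
    spark.conf().set("spark.sql.autoBroadcastJoinThreshold", String.valueOf(10L * 1024 * 1024));

    // Setting -1 disables auto-broadcast entirely, forcing shuffle-based joins.
    // spark.conf().set("spark.sql.autoBroadcastJoinThreshold", "-1");

    spark.stop();
  }
}
```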
18 changes: 14 additions & 4 deletions java/core/lance-jni/src/blocking_dataset.rs
@@ -705,14 +705,24 @@ fn inner_latest_version(env: &mut JNIEnv, java_dataset: JObject) -> Result<u64>
 pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeCountRows(
     mut env: JNIEnv,
     java_dataset: JObject,
-) -> jint {
-    ok_or_throw_with_return!(env, inner_count_rows(&mut env, java_dataset), -1) as jint
+    filter_jobj: JObject, // Optional<String>
+) -> jlong {
+    ok_or_throw_with_return!(
+        env,
+        inner_count_rows(&mut env, java_dataset, filter_jobj),
+        -1
+    ) as jlong
 }
 
-fn inner_count_rows(env: &mut JNIEnv, java_dataset: JObject) -> Result<usize> {
+fn inner_count_rows(
+    env: &mut JNIEnv,
+    java_dataset: JObject,
+    filter_jobj: JObject,
+) -> Result<usize> {
+    let filter = env.get_string_opt(&filter_jobj)?;
     let dataset_guard =
         unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
-    dataset_guard.count_rows(None)
+    dataset_guard.count_rows(filter)
 }
 
 #[no_mangle]
21 changes: 18 additions & 3 deletions java/core/src/main/java/com/lancedb/lance/Dataset.java
@@ -425,14 +425,29 @@ private native void nativeCreateIndex(
    *
    * @return num of rows
    */
-  public int countRows() {
+  public long countRows() {
     try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
       Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
-      return nativeCountRows();
+      return nativeCountRows(Optional.empty());
     }
   }
 
-  private native int nativeCountRows();
+  /**
+   * Count the number of rows in the dataset.
+   *
+   * @param filter the filter expression used to count rows
+   * @return num of rows
+   */
+  public long countRows(String filter) {
+    try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
+      Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
+      Preconditions.checkArgument(
+          null != filter && !filter.isEmpty(), "filter cannot be null or empty");
+      return nativeCountRows(Optional.of(filter));
+    }
+  }
+
+  private native long nativeCountRows(Optional<String> filter);
 
   /**
    * Get all fragments in this dataset.
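
A hedged usage sketch of the widened Java API above. The dataset path is hypothetical, and `ReadOptions.Builder` with default settings is an assumption, since its definition is not part of this diff:

```java
import com.lancedb.lance.Dataset;
import com.lancedb.lance.ReadOptions;

import org.apache.arrow.memory.RootAllocator;

public class CountRowsDemo {
  public static void main(String[] args) {
    String uri = "/tmp/example.lance"; // hypothetical dataset location
    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
        Dataset dataset = Dataset.open(allocator, uri, new ReadOptions.Builder().build())) {
      long total = dataset.countRows(); // counts every row; now returns long instead of int
      long matched = dataset.countRows("id > 2"); // counts only rows matching the filter
      System.out.printf("total=%d, matched=%d%n", total, matched);
    }
  }
}
```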
20 changes: 20 additions & 0 deletions java/core/src/test/java/com/lancedb/lance/DatasetTest.java
@@ -336,4 +336,24 @@ void testTake() throws IOException, ClosedChannelException {
       }
     }
   }
+
+  @Test
+  void testCountRows() {
+    String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
+    String datasetPath = tempDir.resolve(testMethodName).toString();
+    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      TestUtils.SimpleTestDataset testDataset =
+          new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      dataset = testDataset.createEmptyDataset();
+
+      try (Dataset dataset2 = testDataset.write(1, 5)) {
+        assertEquals(5, dataset2.countRows());
+        // the filter matches the rows with id = 3 and id = 4
+        assertEquals(2, dataset2.countRows("id > 2"));
+
+        assertThrows(IllegalArgumentException.class, () -> dataset2.countRows(null));
+        assertThrows(IllegalArgumentException.class, () -> dataset2.countRows(""));
+      }
+    }
+  }
 }
@@ -57,6 +57,17 @@ public static Optional<StructType> getSchema(String datasetUri) {
     }
   }
 
+  public static Optional<Long> getDatasetRowCount(LanceConfig config) {
+    String uri = config.getDatasetUri();
+    ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
+    try (Dataset dataset = Dataset.open(allocator, uri, options)) {
+      return Optional.of(dataset.countRows());
+    } catch (IllegalArgumentException e) {
+      // dataset not found
+      return Optional.empty();
+    }
+  }
+
   public static List<Integer> getFragmentIds(LanceConfig config) {
     String uri = config.getDatasetUri();
     ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
@@ -25,6 +25,8 @@
 import org.apache.spark.sql.connector.read.PartitionReader;
 import org.apache.spark.sql.connector.read.PartitionReaderFactory;
 import org.apache.spark.sql.connector.read.Scan;
+import org.apache.spark.sql.connector.read.Statistics;
+import org.apache.spark.sql.connector.read.SupportsReportStatistics;
 import org.apache.spark.sql.internal.connector.SupportsMetadata;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
@@ -35,7 +37,8 @@
 import java.util.List;
 import java.util.stream.IntStream;
 
-public class LanceScan implements Batch, Scan, SupportsMetadata, Serializable {
+public class LanceScan
+    implements Batch, Scan, SupportsMetadata, SupportsReportStatistics, Serializable {
   private static final long serialVersionUID = 947284762748623947L;
 
   private final StructType schema;
@@ -103,6 +106,11 @@ public Map<String, String> getMetaData() {
     return hashMap.toMap(scala.Predef.conforms());
   }
 
+  @Override
+  public Statistics estimateStatistics() {
+    return new LanceStatistics(config);
+  }
+
   private class LanceReaderFactory implements PartitionReaderFactory {
     @Override
     public PartitionReader<InternalRow> createReader(InputPartition partition) {
@@ -0,0 +1,54 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.lancedb.lance.spark.read;

import com.lancedb.lance.spark.LanceConfig;
import com.lancedb.lance.spark.internal.LanceDatasetAdapter;
import com.lancedb.lance.spark.utils.Optional;

import org.apache.spark.sql.connector.read.Statistics;
import org.apache.spark.sql.types.StructType;

import java.util.OptionalLong;

public class LanceStatistics implements Statistics {
private final Optional<Long> rowNumber;
private final Optional<StructType> schema;

public LanceStatistics(LanceConfig config) {
this.rowNumber = LanceDatasetAdapter.getDatasetRowCount(config);
this.schema = LanceDatasetAdapter.getSchema(config);
}

@Override
public OptionalLong sizeInBytes() {
// TODO: support a fast way to get the on-disk byte size of the lance dataset.
// For now, infer the byte size from the schema's default row size.
if (rowNumber.isPresent()) {
return OptionalLong.of(schema.get().defaultSize() * rowNumber.get());
} else {
return OptionalLong.empty();
}
}

@Override
public OptionalLong numRows() {
if (rowNumber.isPresent()) {
return OptionalLong.of(rowNumber.get());
} else {
return OptionalLong.empty();
}
}
}
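
To make the estimate in `sizeInBytes()` concrete, here is a small illustration. The schema and row count are made up; the per-type numbers are Spark's `defaultSize()` values (8 bytes for a long, 20 for a string), not anything measured from a real dataset:

```java
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

import java.util.OptionalLong;

public class SizeEstimateDemo {
  public static void main(String[] args) {
    // Hypothetical schema: LongType defaults to 8 bytes per row, StringType to 20.
    StructType schema = new StructType()
        .add("id", DataTypes.LongType)
        .add("name", DataTypes.StringType);

    long rowCount = 1_000L; // made-up row count
    OptionalLong sizeInBytes = OptionalLong.of((long) schema.defaultSize() * rowCount);

    // (8 + 20) * 1000 = 28000 bytes, far below the 10 MB broadcast threshold.
    System.out.println(sizeInBytes.getAsLong());
  }
}
```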
@@ -30,6 +30,7 @@
 import java.util.stream.Collectors;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class SparkConnectorReadTest {
   private static SparkSession spark;
@@ -53,6 +54,7 @@ static void setup() {
             LanceConfig.CONFIG_DATASET_URI,
             LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName))
         .load();
+    data.createOrReplaceTempView("test_dataset1");
   }
 
   @AfterAll
@@ -171,4 +173,17 @@ public void supportDataSourceLoadPath() {
         .load(LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName));
     validateData(df, TestUtils.TestTable1Config.expectedValues);
   }
+
+  @Test
+  public void supportBroadcastJoin() {
+    Dataset<Row> df =
+        spark.read().format("lance").load(LanceConfig.getDatasetUri(dbPath, "test_dataset3"));
+    df.createOrReplaceTempView("test_dataset3");
+    List<Row> desc =
+        spark
+            .sql("explain select t1.* from test_dataset1 t1 join test_dataset3 t3 on t1.x = t3.x")
+            .collectAsList();
+    assertEquals(1, desc.size());
+    assertTrue(desc.get(0).getString(0).contains("BroadcastHashJoin"));
+  }
 }
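
A hedged note on the test above: the plan shows `BroadcastHashJoin` because the row-count estimate from `LanceStatistics` puts the table under the broadcast threshold. One way to sanity-check that the statistics drive the plan choice, as a sketch reusing the test's temp views and not part of this commit:

```java
// Disabling auto-broadcast should flip the same query to a shuffle-based join
// (typically SortMergeJoin), confirming the statistics were the deciding factor.
spark.conf().set("spark.sql.autoBroadcastJoinThreshold", "-1");
String plan =
    spark
        .sql("explain select t1.* from test_dataset1 t1 join test_dataset3 t3 on t1.x = t3.x")
        .collectAsList()
        .get(0)
        .getString(0);
assertTrue(!plan.contains("BroadcastHashJoin"));
```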
