From 8a23d50f630deba1f80db7c96c68cd33994ed391 Mon Sep 17 00:00:00 2001
From: huangzhaowei
Date: Sat, 4 Jan 2025 07:25:14 +0800
Subject: [PATCH] feat(java): support statistics row num for lance scan (#3304)

Support reporting the statistics row count for a lance scan. With this
statistic, Spark can choose a broadcast join for a small table. For now the
byte size of the lance dataset is inferred from the row count, which is not
very precise; maybe we should store the file size in the metadata, as
discussed in #3221.
---
 java/core/lance-jni/src/blocking_dataset.rs  | 18 +++++--
 .../main/java/com/lancedb/lance/Dataset.java | 21 ++++++--
 .../java/com/lancedb/lance/DatasetTest.java  | 20 +++++++
 .../spark/internal/LanceDatasetAdapter.java  | 11 ++++
 .../lancedb/lance/spark/read/LanceScan.java  | 10 +++-
 .../lance/spark/read/LanceStatistics.java    | 54 +++++++++++++++++++
 .../spark/read/SparkConnectorReadTest.java   | 15 ++++++
 7 files changed, 141 insertions(+), 8 deletions(-)
 create mode 100644 java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java

diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs
index a52a5c1069..94764751d1 100644
--- a/java/core/lance-jni/src/blocking_dataset.rs
+++ b/java/core/lance-jni/src/blocking_dataset.rs
@@ -705,14 +705,24 @@ fn inner_latest_version(env: &mut JNIEnv, java_dataset: JObject) -> Result
 pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeCountRows(
     mut env: JNIEnv,
     java_dataset: JObject,
-) -> jint {
-    ok_or_throw_with_return!(env, inner_count_rows(&mut env, java_dataset), -1) as jint
+    filter_jobj: JObject, // Optional<String>
+) -> jlong {
+    ok_or_throw_with_return!(
+        env,
+        inner_count_rows(&mut env, java_dataset, filter_jobj),
+        -1
+    ) as jlong
 }
 
-fn inner_count_rows(env: &mut JNIEnv, java_dataset: JObject) -> Result {
+fn inner_count_rows(
+    env: &mut JNIEnv,
+    java_dataset: JObject,
+    filter_jobj: JObject,
+) -> Result {
+    let filter = env.get_string_opt(&filter_jobj)?;
     let dataset_guard =
         unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
-    dataset_guard.count_rows(None)
+    dataset_guard.count_rows(filter)
 }
 
 #[no_mangle]
diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java
index 0f7cb9920a..975c9b1d43 100644
--- a/java/core/src/main/java/com/lancedb/lance/Dataset.java
+++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java
@@ -425,14 +425,29 @@ private native void nativeCreateIndex(
    *
    * @return num of rows
    */
-  public int countRows() {
+  public long countRows() {
     try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
       Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
-      return nativeCountRows();
+      return nativeCountRows(Optional.empty());
     }
   }
 
-  private native int nativeCountRows();
+  /**
+   * Count the number of rows in the dataset that match a filter.
+   *
+   * @param filter the filter expression used to select the rows to count
+   * @return number of rows matching the filter
+   */
+  public long countRows(String filter) {
+    try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
+      Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
+      Preconditions.checkArgument(
+          null != filter && !filter.isEmpty(), "filter cannot be null or empty");
+      return nativeCountRows(Optional.of(filter));
+    }
+  }
+
+  private native long nativeCountRows(Optional<String> filter);
 
   /**
    * Get all fragments in this dataset.
diff --git a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
index 4275ef9573..73e48d47d8 100644
--- a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
+++ b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
@@ -336,4 +336,24 @@ void testTake() throws IOException, ClosedChannelException {
       }
     }
   }
+
+  @Test
+  void testCountRows() {
+    String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
+    String datasetPath = tempDir.resolve(testMethodName).toString();
+    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      TestUtils.SimpleTestDataset testDataset =
+          new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      dataset = testDataset.createEmptyDataset();
+
+      try (Dataset dataset2 = testDataset.write(1, 5)) {
+        assertEquals(5, dataset2.countRows());
+        // get id = 3 and 4
+        assertEquals(2, dataset2.countRows("id > 2"));
+
+        assertThrows(IllegalArgumentException.class, () -> dataset2.countRows(null));
+        assertThrows(IllegalArgumentException.class, () -> dataset2.countRows(""));
+      }
+    }
+  }
 }
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
index b5938bc65c..111a94e6c4 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
@@ -57,6 +57,17 @@ public static Optional<StructType> getSchema(String datasetUri) {
     }
   }
 
+  public static Optional<Long> getDatasetRowCount(LanceConfig config) {
+    String uri = config.getDatasetUri();
+    ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
+    try (Dataset dataset = Dataset.open(allocator, uri, options)) {
+      return Optional.of(dataset.countRows());
+    } catch (IllegalArgumentException e) {
+      // dataset not found
+      return Optional.empty();
+    }
+  }
+
   public static List<Integer> getFragmentIds(LanceConfig config) {
     String uri = config.getDatasetUri();
     ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java
index 9455e5c444..5935269708 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java
@@ -25,6 +25,8 @@
 import org.apache.spark.sql.connector.read.PartitionReader;
 import org.apache.spark.sql.connector.read.PartitionReaderFactory;
 import org.apache.spark.sql.connector.read.Scan;
+import org.apache.spark.sql.connector.read.Statistics;
+import org.apache.spark.sql.connector.read.SupportsReportStatistics;
 import org.apache.spark.sql.internal.connector.SupportsMetadata;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
@@ -35,7 +37,8 @@
 import java.util.List;
 import java.util.stream.IntStream;
 
-public class LanceScan implements Batch, Scan, SupportsMetadata, Serializable {
+public class LanceScan
+    implements Batch, Scan, SupportsMetadata, SupportsReportStatistics, Serializable {
   private static final long serialVersionUID = 947284762748623947L;
 
   private final StructType schema;
@@ -103,6 +106,11 @@ public Map<String, String> getMetaData() {
     return hashMap.toMap(scala.Predef.conforms());
   }
 
+  @Override
+  public Statistics estimateStatistics() {
+    return new LanceStatistics(config);
+  }
+
   private class LanceReaderFactory implements PartitionReaderFactory {
     @Override
     public PartitionReader<InternalRow> createReader(InputPartition partition) {
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java
new file mode 100644
index 0000000000..5a6b41f9ad
--- /dev/null
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.lancedb.lance.spark.read;
+
+import com.lancedb.lance.spark.LanceConfig;
+import com.lancedb.lance.spark.internal.LanceDatasetAdapter;
+import com.lancedb.lance.spark.utils.Optional;
+
+import org.apache.spark.sql.connector.read.Statistics;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.OptionalLong;
+
+public class LanceStatistics implements Statistics {
+  private final Optional<Long> rowNumber;
+  private final Optional<StructType> schema;
+
+  public LanceStatistics(LanceConfig config) {
+    this.rowNumber = LanceDatasetAdapter.getDatasetRowCount(config);
+    this.schema = LanceDatasetAdapter.getSchema(config);
+  }
+
+  @Override
+  public OptionalLong sizeInBytes() {
+    // TODO: support quickly getting the on-disk byte size of the lance dataset.
+    // For now, infer the byte size from the schema's default row size.
+    if (rowNumber.isPresent()) {
+      return OptionalLong.of(schema.get().defaultSize() * rowNumber.get());
+    } else {
+      return OptionalLong.empty();
+    }
+  }
+
+  @Override
+  public OptionalLong numRows() {
+    if (rowNumber.isPresent()) {
+      return OptionalLong.of(rowNumber.get());
+    } else {
+      return OptionalLong.empty();
+    }
+  }
+}
diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java
index 1b3bbd372c..7628d92843 100644
--- a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java
+++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java
@@ -30,6 +30,7 @@
 import java.util.stream.Collectors;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class SparkConnectorReadTest {
   private static SparkSession spark;
@@ -53,6 +54,7 @@ static void setup() {
             LanceConfig.CONFIG_DATASET_URI,
             LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName))
         .load();
+    data.createOrReplaceTempView("test_dataset1");
   }
 
   @AfterAll
@@ -171,4 +173,17 @@ public void supportDataSourceLoadPath() {
         .load(LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName));
     validateData(df, TestUtils.TestTable1Config.expectedValues);
   }
+
+  @Test
+  public void supportBroadcastJoin() {
+    Dataset<Row> df =
+        spark.read().format("lance").load(LanceConfig.getDatasetUri(dbPath, "test_dataset3"));
+    df.createOrReplaceTempView("test_dataset3");
+    List<Row> desc =
+        spark
+            .sql("explain select t1.* from test_dataset1 t1 join test_dataset3 t3 on t1.x = t3.x")
+            .collectAsList();
+    assertEquals(1, desc.size());
+    assertTrue(desc.get(0).getString(0).contains("BroadcastHashJoin"));
+  }
 }
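
Usage note (not part of the patch): the row count reported through estimateStatistics() feeds Spark's join
planning via spark.sql.autoBroadcastJoinThreshold. Below is a minimal, hypothetical sketch of how the
broadcast decision becomes visible to a user of the connector. The dataset paths, the "id" join column, and
the threshold value are assumptions for illustration; it requires the Lance Spark connector on the classpath
and existing Lance datasets at those paths.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class LanceBroadcastJoinExample {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .appName("lance-broadcast-join-example")
            .master("local[*]")
            // Spark broadcasts a relation whose estimated size is below this threshold;
            // with this patch the estimate is rowCount * schema.defaultSize().
            .config("spark.sql.autoBroadcastJoinThreshold", String.valueOf(10L * 1024 * 1024))
            .getOrCreate();

    // Hypothetical dataset paths; any small and large Lance datasets would do.
    Dataset<Row> small = spark.read().format("lance").load("/tmp/lance/small_table.lance");
    Dataset<Row> large = spark.read().format("lance").load("/tmp/lance/large_table.lance");
    small.createOrReplaceTempView("small_table");
    large.createOrReplaceTempView("large_table");

    // With row-count statistics reported by the scan, the physical plan should show
    // BroadcastHashJoin for the small side instead of a sort-merge join.
    spark.sql("SELECT l.* FROM large_table l JOIN small_table s ON l.id = s.id").explain();

    spark.stop();
  }
}

Because sizeInBytes() is only rowCount * defaultSize(), the estimate can over- or under-shoot the real
on-disk size, which is the imprecision the commit message calls out.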