From 5649410691b6f80fadd802da22e92aa24c7915f4 Mon Sep 17 00:00:00 2001
From: huangzhaowei <huangzhaowei.416@bytedance.com>
Date: Sat, 4 Jan 2025 18:33:52 +0800
Subject: [PATCH 1/3] support get real data size for LanceStatistics

---
 java/core/lance-jni/src/blocking_dataset.rs   | 44 +++++++++++++++++++
 .../main/java/com/lancedb/lance/Dataset.java  | 20 +++++++++
 .../com/lancedb/lance/ipc/DataStatistics.java | 44 +++++++++++++++++++
 .../lancedb/lance/ipc/FieldStatistics.java    | 39 ++++++++++++++++
 .../java/com/lancedb/lance/DatasetTest.java   | 15 +++++++
 .../spark/internal/LanceDatasetAdapter.java   | 11 +++++
 .../lance/spark/read/LanceStatistics.java     | 11 ++---
 7 files changed, 177 insertions(+), 7 deletions(-)
 create mode 100644 java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java
 create mode 100644 java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java

diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs
index 94764751d1..7ef687c9cb 100644
--- a/java/core/lance-jni/src/blocking_dataset.rs
+++ b/java/core/lance-jni/src/blocking_dataset.rs
@@ -30,6 +30,7 @@ use jni::sys::{jboolean, jint};
 use jni::sys::{jbyteArray, jlong};
 use jni::{objects::JObject, JNIEnv};
 use lance::dataset::builder::DatasetBuilder;
+use lance::dataset::statistics::{DataStatistics, DatasetStatisticsExt};
 use lance::dataset::transaction::Operation;
 use lance::dataset::{ColumnAlteration, Dataset, ProjectionRequest, ReadParams, WriteParams};
 use lance::io::{ObjectStore, ObjectStoreParams};
@@ -154,6 +155,11 @@ impl BlockingDataset {
         Ok(rows)
     }
 
+    pub fn calculate_data_stats(&self) -> Result<DataStatistics> {
+        let stats = RT.block_on(Arc::new(self.clone().inner).calculate_data_stats())?;
+        Ok(stats)
+    }
+
     pub fn list_indexes(&self) -> Result<Arc<Vec<Index>>> {
         let indexes = RT.block_on(self.inner.load_indices())?;
         Ok(indexes)
@@ -725,6 +731,44 @@ fn inner_count_rows(
     dataset_guard.count_rows(filter)
 }
 
+#[no_mangle]
+pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeGetDataStatistics<'local>(
+    mut env: JNIEnv<'local>,
+    java_dataset: JObject,
+) -> JObject<'local> {
+    ok_or_throw!(env, inner_get_data_statistics(&mut env, java_dataset))
+}
+
+fn inner_get_data_statistics<'local>(
+    env: &mut JNIEnv<'local>,
+    java_dataset: JObject,
+) -> Result<JObject<'local>> {
+    let stats = {
+        let dataset_guard =
+            unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
+        let stats = dataset_guard.calculate_data_stats()?;
+        stats
+    };
+    let data_stats = env.new_object("com/lancedb/lance/ipc/DataStatistics", "()V", &[])?;
+
+    for field in stats.fields {
+        let id = field.id as jint;
+        let byte_size = field.bytes_on_disk as jlong;
+        let filed_jobj = env.new_object(
+            "com/lancedb/lance/ipc/FieldStatistics",
+            "(IJ)V",
+            &[JValue::Int(id), JValue::Long(byte_size)],
+        )?;
+        env.call_method(
+            &data_stats,
+            "addFiledStatistics",
+            "(Lcom/lancedb/lance/ipc/FieldStatistics;)V",
+            &[JValue::Object(&filed_jobj)],
+        )?;
+    }
+    Ok(data_stats)
+}
+
 #[no_mangle]
 pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeListIndexes<'local>(
     mut env: JNIEnv<'local>,
diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java
index 88c945b71d..1a2baa43d6 100644
--- a/java/core/src/main/java/com/lancedb/lance/Dataset.java
+++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java
@@ -15,6 +15,7 @@
 
 import com.lancedb.lance.index.IndexParams;
 import com.lancedb.lance.index.IndexType;
+import com.lancedb.lance.ipc.DataStatistics;
 import com.lancedb.lance.ipc.LanceScanner;
 import com.lancedb.lance.ipc.ScanOptions;
 import com.lancedb.lance.schema.ColumnAlteration;
@@ -450,6 +451,25 @@ public long countRows(String filter) {
 
   private native long nativeCountRows(Optional<String> filter);
 
+  /**
+   * Calculate the size of the dataset.
+   *
+   * @return the size of the dataset
+   */
+  public long calculateDataSize() {
+    try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
+      Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
+      return nativeGetDataStatistics().getDataSize();
+    }
+  }
+
+  /**
+   * Calculate the statistics of the dataset.
+   *
+   * @return the statistics of the dataset
+   */
+  private native DataStatistics nativeGetDataStatistics();
+
   /**
    * Get all fragments in this dataset.
    *
diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java b/java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java
new file mode 100644
index 0000000000..b7e7e2fcd9
--- /dev/null
+++ b/java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.ipc;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+public class DataStatistics implements Serializable {
+  private final List<FieldStatistics> fields;
+
+  public DataStatistics() {
+    this.fields = new ArrayList<>();
+  }
+
+  // used for rust to add field statistics
+  public void addFiledStatistics(FieldStatistics fieldStatistics) {
+    fields.add(fieldStatistics);
+  }
+
+  public List<FieldStatistics> getFields() {
+    return fields;
+  }
+
+  public long getDataSize() {
+    return fields.stream().mapToLong(FieldStatistics::getDataSize).sum();
+  }
+
+  @Override
+  public String toString() {
+    return "DataStatistics{" + "fields=" + fields + '}';
+  }
+}
diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java b/java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java
new file mode 100644
index 0000000000..8941201d53
--- /dev/null
+++ b/java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.ipc;
+
+import java.io.Serializable;
+
+public class FieldStatistics implements Serializable {
+  private final int id;
+  private final long dataSize;
+
+  public FieldStatistics(int id, long dataSize) {
+    this.id = id;
+    this.dataSize = dataSize;
+  }
+
+  public int getId() {
+    return id;
+  }
+
+  public long getDataSize() {
+    return dataSize;
+  }
+
+  @Override
+  public String toString() {
+    return "FieldStatistics{" + "id=" + id + ", dataSize=" + dataSize + '}';
+  }
+}
diff --git a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
index 25717d38b6..dc3dec04f8 100644
--- a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
+++ b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java
@@ -358,4 +358,19 @@ void testCountRows() {
       }
     }
   }
+
+  @Test
+  void testCalculateDataSize() {
+    String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
+    String datasetPath = tempDir.resolve(testMethodName).toString();
+    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      TestUtils.SimpleTestDataset testDataset =
+          new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      dataset = testDataset.createEmptyDataset();
+
+      try (Dataset dataset2 = testDataset.write(1, 5)) {
+        assertEquals(100, dataset2.calculateDataSize());
+      }
+    }
+  }
 }
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
index c5fa24ac13..72b36a8aa3 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
@@ -67,6 +67,17 @@ public static Optional<Long> getDatasetRowCount(LanceConfig config) {
     }
   }
 
+  public static Optional<Long> getDatasetDataSize(LanceConfig config) {
+    String uri = config.getDatasetUri();
+    ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
+    try (Dataset dataset = Dataset.open(allocator, uri, options)) {
+      return Optional.of(dataset.calculateDataSize());
+    } catch (IllegalArgumentException e) {
+      // dataset not found
+      return Optional.empty();
+    }
+  }
+
   public static List<Integer> getFragmentIds(LanceConfig config) {
     String uri = config.getDatasetUri();
     ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java
index 6300561d68..cb098caf42 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java
@@ -18,25 +18,22 @@
 import com.lancedb.lance.spark.utils.Optional;
 
 import org.apache.spark.sql.connector.read.Statistics;
-import org.apache.spark.sql.types.StructType;
 
 import java.util.OptionalLong;
 
 public class LanceStatistics implements Statistics {
   private final Optional<Long> rowNumber;
-  private final Optional<StructType> schema;
+  private final Optional<Long> dataBytesSize;
 
   public LanceStatistics(LanceConfig config) {
     this.rowNumber = LanceDatasetAdapter.getDatasetRowCount(config);
-    this.schema = LanceDatasetAdapter.getSchema(config);
+    this.dataBytesSize = LanceDatasetAdapter.getDatasetDataSize(config);
   }
 
   @Override
   public OptionalLong sizeInBytes() {
-    // TODO: Support quickly get the bytes on disk for the lance dataset
-    // Now use schema to infer the byte size for simple
-    if (rowNumber.isPresent()) {
-      return OptionalLong.of(schema.get().defaultSize() * rowNumber.get());
+    if (dataBytesSize.isPresent()) {
+      return OptionalLong.of(dataBytesSize.get());
     } else {
       return OptionalLong.empty();
     }

From 70f07ffe81eb79d450c43cb2b4965255d19f3072 Mon Sep 17 00:00:00 2001
From: huangzhaowei <huangzhaowei.416@bytedance.com>
Date: Sat, 4 Jan 2025 19:31:13 +0800
Subject: [PATCH 2/3] fix rust format

---
 java/core/lance-jni/src/blocking_dataset.rs | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs
index 7ef687c9cb..15412edee0 100644
--- a/java/core/lance-jni/src/blocking_dataset.rs
+++ b/java/core/lance-jni/src/blocking_dataset.rs
@@ -746,8 +746,7 @@ fn inner_get_data_statistics<'local>(
     let stats = {
         let dataset_guard =
             unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
-        let stats = dataset_guard.calculate_data_stats()?;
-        stats
+        dataset_guard.calculate_data_stats()?
     };
     let data_stats = env.new_object("com/lancedb/lance/ipc/DataStatistics", "()V", &[])?;
 

From 813de88fdf4886a1a74ddf22a71cbdc2b5c7161d Mon Sep 17 00:00:00 2001
From: huangzhaowei <huangzhaowei.416@bytedance.com>
Date: Sun, 5 Jan 2025 10:53:47 +0800
Subject: [PATCH 3/3] add comment for dataSize method

---
 .../core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java | 1 +
 .../src/main/java/com/lancedb/lance/ipc/FieldStatistics.java     | 1 +
 2 files changed, 2 insertions(+)

diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java b/java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java
index b7e7e2fcd9..fad3086f9f 100644
--- a/java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java
+++ b/java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java
@@ -33,6 +33,7 @@ public List<FieldStatistics> getFields() {
     return fields;
   }
 
+  // get total data size of the whole dataset in bytes
   public long getDataSize() {
     return fields.stream().mapToLong(FieldStatistics::getDataSize).sum();
   }
diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java b/java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java
index 8941201d53..34b83cd2d1 100644
--- a/java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java
+++ b/java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java
@@ -17,6 +17,7 @@
 
 public class FieldStatistics implements Serializable {
   private final int id;
+  // The size of the data in bytes
   private final long dataSize;
 
   public FieldStatistics(int id, long dataSize) {