From 39222ec930623b2d27a5636ad068a8378cc0ae35 Mon Sep 17 00:00:00 2001
From: huangzhaowei
Date: Tue, 3 Dec 2024 04:20:13 +0800
Subject: [PATCH] feat: support write multi fragments or empty fragment in one
 spark task (#3183)

`FileFragment::create` currently only supports creating a single file
fragment, which causes two issues in the Spark connector:

1. If the Spark task is empty, this API raises an exception, since there is
   no data to create a fragment from.
2. If the task's data stream is very large, it generates one huge file in
   Lance format, which is not friendly to Spark parallelism.

So this patch removes the assigned fragment id and adds a new method,
`FileFragment::create_fragments`, which can generate zero or multiple
fragments.

![image](https://github.com/user-attachments/assets/54fb2497-8163-4652-9e0b-d50a88fade53)
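For illustration, a minimal usage sketch of the new Rust API (not part of this patch): the schema, the 20-row batch, and the `max_rows_per_file: 10` split are hypothetical values; boxing a `RecordBatchIterator` as the `StreamingWriteSource` follows the tests added in `write.rs` below.

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use lance::dataset::fragment::FileFragment;
use lance::dataset::WriteParams;

async fn write_task(dataset_uri: &str) -> lance::Result<()> {
    // Hypothetical single-column schema with one 20-row batch.
    let schema = Arc::new(ArrowSchema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from_iter_values(0..20)) as ArrayRef],
    )
    .expect("valid batch");
    let source = Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema));

    // No fragment id is passed anymore; per the new tests, unassigned ids
    // default to 0 until the fragments are committed.
    let params = WriteParams {
        max_rows_per_file: 10, // hypothetical split: 20 rows -> 2 fragments
        ..Default::default()
    };

    // An empty source yields an empty Vec instead of an error; a large
    // stream is split into multiple fragments.
    let fragments = FileFragment::create_fragments(dataset_uri, source, Some(params)).await?;
    assert_eq!(fragments.len(), 2);
    Ok(())
}
```

On the Spark side, `LanceDataWriter.commit()` now passes the whole returned list to `BatchAppend.TaskCommit`, so an empty task commits zero fragments instead of throwing.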
---
 java/core/lance-jni/src/fragment.rs           |  15 +-
 java/core/lance-jni/src/utils.rs              |   4 +-
 .../main/java/com/lancedb/lance/Fragment.java |  29 ++--
 .../com/lancedb/lance/FragmentMetadata.java   |  27 ++++
 .../java/com/lancedb/lance/FragmentTest.java  |  36 +++--
 .../java/com/lancedb/lance/ScannerTest.java   |  30 ++--
 .../java/com/lancedb/lance/TestUtils.java     |  22 +--
 .../com/lancedb/lance/TestVectorDataset.java  |   6 +-
 .../spark/internal/LanceDatasetAdapter.java   |   8 +-
 .../lance/spark/write/LanceDataWriter.java    |  14 +-
 .../lance/spark/write/SparkWriteTest.java     |  28 ++++
 rust/lance/src/dataset/fragment.rs            |  15 ++
 rust/lance/src/dataset/fragment/write.rs      | 128 +++++++++++++++++-
 13 files changed, 281 insertions(+), 81 deletions(-)

diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs
index 66182b2d44..dacdd08798 100644
--- a/java/core/lance-jni/src/fragment.rs
+++ b/java/core/lance-jni/src/fragment.rs
@@ -29,7 +29,6 @@ use lance_datafusion::utils::StreamingWriteSource;
 use crate::error::{Error, Result};
 use crate::{
     blocking_dataset::{BlockingDataset, NATIVE_DATASET},
-    ffi::JNIEnvExt,
     traits::FromJString,
     utils::extract_write_params,
     RT,
@@ -77,7 +76,6 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local>(
     dataset_uri: JString,
     arrow_array_addr: jlong,
     arrow_schema_addr: jlong,
-    fragment_id: JObject,        // Optional<Integer>
     max_rows_per_file: JObject,  // Optional<Integer>
     max_rows_per_group: JObject, // Optional<Integer>
     max_bytes_per_file: JObject, // Optional<Long>
@@ -91,7 +89,6 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local>(
         dataset_uri,
         arrow_array_addr,
         arrow_schema_addr,
-        fragment_id,
         max_rows_per_file,
         max_rows_per_group,
         max_bytes_per_file,
@@ -108,7 +105,6 @@ fn inner_create_with_ffi_array<'local>(
     dataset_uri: JString,
     arrow_array_addr: jlong,
     arrow_schema_addr: jlong,
-    fragment_id: JObject,        // Optional<Integer>
     max_rows_per_file: JObject,  // Optional<Integer>
     max_rows_per_group: JObject, // Optional<Integer>
     max_bytes_per_file: JObject, // Optional<Long>
@@ -131,7 +127,6 @@ fn inner_create_with_ffi_array<'local>(
     create_fragment(
         env,
         dataset_uri,
-        fragment_id,
         max_rows_per_file,
         max_rows_per_group,
         max_bytes_per_file,
@@ -147,7 +142,6 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiStream<'a>(
     _obj: JObject,
     dataset_uri: JString,
     arrow_array_stream_addr: jlong,
-    fragment_id: JObject,        // Optional<Integer>
     max_rows_per_file: JObject,  // Optional<Integer>
     max_rows_per_group: JObject, // Optional<Integer>
     max_bytes_per_file: JObject, // Optional<Long>
@@ -160,7 +154,6 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiStream<'a>(
         &mut env,
         dataset_uri,
         arrow_array_stream_addr,
-        fragment_id,
         max_rows_per_file,
         max_rows_per_group,
         max_bytes_per_file,
@@ -176,7 +169,6 @@ fn inner_create_with_ffi_stream<'local>(
     env: &mut JNIEnv<'local>,
     dataset_uri: JString,
     arrow_array_stream_addr: jlong,
-    fragment_id: JObject,        // Optional<Integer>
     max_rows_per_file: JObject,  // Optional<Integer>
     max_rows_per_group: JObject, // Optional<Integer>
     max_bytes_per_file: JObject, // Optional<Long>
@@ -189,7 +181,6 @@ fn inner_create_with_ffi_stream<'local>(
     create_fragment(
         env,
         dataset_uri,
-        fragment_id,
         max_rows_per_file,
         max_rows_per_group,
         max_bytes_per_file,
@@ -203,7 +194,6 @@ fn inner_create_with_ffi_stream<'local>(
 fn create_fragment<'a>(
     env: &mut JNIEnv<'a>,
     dataset_uri: JString,
-    fragment_id: JObject,        // Optional<Integer>
     max_rows_per_file: JObject,  // Optional<Integer>
     max_rows_per_group: JObject, // Optional<Integer>
     max_bytes_per_file: JObject, // Optional<Long>
@@ -213,8 +203,6 @@ fn create_fragment<'a>(
 ) -> Result<JString<'a>> {
     let path_str = dataset_uri.extract(env)?;
 
-    let fragment_id_opts = env.get_int_opt(&fragment_id)?;
-
     let write_params = extract_write_params(
         env,
         &max_rows_per_file,
@@ -223,9 +211,8 @@ fn create_fragment<'a>(
         &mode,
         &storage_options_obj,
     )?;
-    let fragment = RT.block_on(FileFragment::create(
+    let fragment = RT.block_on(FileFragment::create_fragments(
         &path_str,
-        fragment_id_opts.unwrap_or(0) as usize,
         source,
         Some(write_params),
     ))?;
diff --git a/java/core/lance-jni/src/utils.rs b/java/core/lance-jni/src/utils.rs
index 6b15d4d58b..5f780de6c5 100644
--- a/java/core/lance-jni/src/utils.rs
+++ b/java/core/lance-jni/src/utils.rs
@@ -56,8 +56,8 @@ pub fn extract_write_params(
     if let Some(mode_val) = env.get_string_opt(mode)? {
         write_params.mode = WriteMode::try_from(mode_val.as_str())?;
     }
-    // Java code always sets the data storage version to Legacy for now
-    write_params.data_storage_version = Some(LanceFileVersion::Legacy);
+    // Java code always sets the data storage version to stable for now
+    write_params.data_storage_version = Some(LanceFileVersion::Stable);
     let jmap = JMap::from_env(env, storage_options_obj)?;
     let storage_options: HashMap<String, String> = env.with_local_frame(16, |env| {
         let mut map = HashMap::new();
diff --git a/java/core/src/main/java/com/lancedb/lance/Fragment.java b/java/core/src/main/java/com/lancedb/lance/Fragment.java
index db994a6e4a..fed5a95695 100644
--- a/java/core/src/main/java/com/lancedb/lance/Fragment.java
+++ b/java/core/src/main/java/com/lancedb/lance/Fragment.java
@@ -14,6 +14,7 @@
 
 package com.lancedb.lance;
 
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import org.apache.arrow.c.ArrowArray;
@@ -36,24 +37,22 @@ public class Fragment {
    * @param datasetUri the dataset uri
    * @param allocator the buffer allocator
    * @param root the vector schema root
-   * @param fragmentId the fragment id
    * @param params the write params
    * @return the fragment metadata
    */
-  public static FragmentMetadata create(String datasetUri, BufferAllocator allocator,
-      VectorSchemaRoot root, Optional<Integer> fragmentId, WriteParams params) {
+  public static List<FragmentMetadata> create(String datasetUri, BufferAllocator allocator,
+      VectorSchemaRoot root, WriteParams params) {
     Preconditions.checkNotNull(datasetUri);
     Preconditions.checkNotNull(allocator);
     Preconditions.checkNotNull(root);
-    Preconditions.checkNotNull(fragmentId);
     Preconditions.checkNotNull(params);
     try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator);
         ArrowArray arrowArray = ArrowArray.allocateNew(allocator)) {
       Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchema);
-      return FragmentMetadata.fromJson(createWithFfiArray(datasetUri, arrowArray.memoryAddress(),
-          arrowSchema.memoryAddress(), fragmentId, params.getMaxRowsPerFile(),
-          params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(), params.getMode(),
-          params.getStorageOptions()));
+      return FragmentMetadata.fromJsonArray(createWithFfiArray(datasetUri,
+          arrowArray.memoryAddress(), arrowSchema.memoryAddress(),
+          params.getMaxRowsPerFile(), params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(),
+          params.getMode(), params.getStorageOptions()));
     }
   }
 
@@ -61,18 +60,16 @@ public static FragmentMetadata create(String datasetUri, BufferAllocator allocat
    * Create a fragment from the given arrow stream.
    * @param datasetUri the dataset uri
    * @param stream the arrow stream
-   * @param fragmentId the fragment id
    * @param params the write params
    * @return the fragment metadata
    */
-  public static FragmentMetadata create(String datasetUri, ArrowArrayStream stream,
-      Optional<Integer> fragmentId, WriteParams params) {
+  public static List<FragmentMetadata> create(String datasetUri, ArrowArrayStream stream,
+      WriteParams params) {
     Preconditions.checkNotNull(datasetUri);
     Preconditions.checkNotNull(stream);
-    Preconditions.checkNotNull(fragmentId);
     Preconditions.checkNotNull(params);
-    return FragmentMetadata.fromJson(createWithFfiStream(datasetUri,
-        stream.memoryAddress(), fragmentId,
+    return FragmentMetadata.fromJsonArray(createWithFfiStream(datasetUri,
+        stream.memoryAddress(),
         params.getMaxRowsPerFile(), params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(),
         params.getMode(), params.getStorageOptions()));
   }
@@ -83,7 +80,7 @@ public static FragmentMetadata create(String datasetUri, ArrowArrayStream stream
    * @return the json serialized fragment metadata
    */
   private static native String createWithFfiArray(String datasetUri,
-      long arrowArrayMemoryAddress, long arrowSchemaMemoryAddress, Optional<Integer> fragmentId,
+      long arrowArrayMemoryAddress, long arrowSchemaMemoryAddress,
       Optional<Integer> maxRowsPerFile, Optional<Integer> maxRowsPerGroup,
       Optional<Long> maxBytesPerFile, Optional<String> mode,
       Map<String, String> storageOptions);
@@ -93,7 +90,7 @@ private static native String createWithFfiArray(String datasetUri,
    * @return the json serialized fragment metadata
    */
   private static native String createWithFfiStream(String datasetUri, long arrowStreamMemoryAddress,
-      Optional<Integer> fragmentId, Optional<Integer> maxRowsPerFile,
+      Optional<Integer> maxRowsPerFile,
       Optional<Integer> maxRowsPerGroup, Optional<Long> maxBytesPerFile,
       Optional<String> mode, Map<String, String> storageOptions);
 }
diff --git a/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java b/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java
index c2b5d665a2..c7f0f277cb 100644
--- a/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java
+++ b/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java
@@ -15,7 +15,11 @@
 
 package com.lancedb.lance;
 
 import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.arrow.util.Preconditions;
+import org.json.JSONArray;
 import org.json.JSONObject;
 import org.apache.commons.lang3.builder.ToStringBuilder;
@@ -75,4 +79,27 @@ public static FragmentMetadata fromJson(String jsonMetadata) {
     return new FragmentMetadata(jsonMetadata, metadata.getInt(ID_KEY),
         metadata.getLong(PHYSICAL_ROWS_KEY));
   }
+
+  /**
+   * Converts a JSON array string into a list of FragmentMetadata objects.
+   *
+   * @param jsonMetadata A JSON array string containing fragment metadata.
+   * @return A list of FragmentMetadata objects.
+   */
+  public static List<FragmentMetadata> fromJsonArray(String jsonMetadata) {
+    Preconditions.checkNotNull(jsonMetadata);
+    JSONArray metadatas = new JSONArray(jsonMetadata);
+    List<FragmentMetadata> fragmentMetadataList = new ArrayList<>();
+    for (Object object : metadatas) {
+      JSONObject metadata = (JSONObject) object;
+      if (!metadata.has(ID_KEY) || !metadata.has(PHYSICAL_ROWS_KEY)) {
+        throw new IllegalArgumentException(
+            String.format("Fragment metadata must have %s and %s but is %s",
+                ID_KEY, PHYSICAL_ROWS_KEY, jsonMetadata));
+      }
+      fragmentMetadataList.add(new FragmentMetadata(metadata.toString(), metadata.getInt(ID_KEY),
+          metadata.getLong(PHYSICAL_ROWS_KEY)));
+    }
+    return fragmentMetadataList;
+  }
 }
diff --git a/java/core/src/test/java/com/lancedb/lance/FragmentTest.java b/java/core/src/test/java/com/lancedb/lance/FragmentTest.java
index a9fbe6c017..c7de20fc99 100644
--- a/java/core/src/test/java/com/lancedb/lance/FragmentTest.java
+++ b/java/core/src/test/java/com/lancedb/lance/FragmentTest.java
@@ -21,7 +21,7 @@
 import java.nio.file.Path;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.HashMap;
+import java.util.List;
 import java.util.Optional;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.types.pojo.Schema;
@@ -37,7 +37,7 @@ void testFragmentCreateFfiArray() {
     try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      testDataset.createNewFragment(123, 20);
+      testDataset.createNewFragment(20);
     }
   }
 
@@ -47,9 +47,8 @@ void testFragmentCreate() throws Exception {
     try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      int fragmentId = 312;
       int rowCount = 21;
-      FragmentMetadata fragmentMeta = testDataset.createNewFragment(fragmentId, rowCount);
+      FragmentMetadata fragmentMeta = testDataset.createNewFragment(rowCount);
 
       // Commit fragment
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(fragmentMeta));
@@ -58,8 +57,7 @@ void testFragmentCreate() throws Exception {
         assertEquals(2, dataset.latestVersion());
         assertEquals(rowCount, dataset.countRows());
         DatasetFragment fragment = dataset.getFragments().get(0);
-        assertEquals(fragmentId, fragment.getId());
-
+
         try (LanceScanner scanner = fragment.newScan()) {
           Schema schemaRes = scanner.schema();
           assertEquals(testDataset.getSchema(), schemaRes);
@@ -74,7 +72,7 @@ void commitWithoutVersion() {
     try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      FragmentMetadata meta = testDataset.createNewFragment(123, 20);
+      FragmentMetadata meta = testDataset.createNewFragment(20);
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(meta));
       assertThrows(IllegalArgumentException.class, () -> {
         Dataset.commit(allocator, datasetPath, appendOp, Optional.empty());
@@ -88,7 +86,7 @@ void commitOldVersion() {
     try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      FragmentMetadata meta = testDataset.createNewFragment(123, 20);
+      FragmentMetadata meta = testDataset.createNewFragment(20);
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(meta));
       assertThrows(IllegalArgumentException.class, () -> {
         Dataset.commit(allocator, datasetPath, appendOp, Optional.of(0L));
@@ -107,4 +105,26 @@ void appendWithoutFragment() {
       });
     }
   }
+
+  @Test
+  void testEmptyFragments() {
+    String datasetPath = tempDir.resolve("testEmptyFragments").toString();
+    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      testDataset.createEmptyDataset().close();
+      List<FragmentMetadata> fragments = testDataset.createNewFragment(0, 10);
+      assertEquals(0, fragments.size());
+    }
+  }
+
+  @Test
+  void testMultiFragments() {
+    String datasetPath = tempDir.resolve("testMultiFragments").toString();
+    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      testDataset.createEmptyDataset().close();
+      List<FragmentMetadata> fragments = testDataset.createNewFragment(20, 10);
+      assertEquals(2, fragments.size());
+    }
+  }
 }
diff --git a/java/core/src/test/java/com/lancedb/lance/ScannerTest.java b/java/core/src/test/java/com/lancedb/lance/ScannerTest.java
index fc46a95c52..11d55a087d 100644
--- a/java/core/src/test/java/com/lancedb/lance/ScannerTest.java
+++ b/java/core/src/test/java/com/lancedb/lance/ScannerTest.java
@@ -225,17 +225,16 @@ void testScanFragment() throws Exception {
     try (BufferAllocator allocator = new RootAllocator()) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      int[] fragment0 = new int[]{0, 3};
-      int[] fragment1 = new int[]{1, 5};
-      int[] fragment2 = new int[]{2, 7};
-      FragmentMetadata metadata0 = testDataset.createNewFragment(fragment0[0], fragment0[1]);
-      FragmentMetadata metadata1 = testDataset.createNewFragment(fragment1[0], fragment1[1]);
-      FragmentMetadata metadata2 = testDataset.createNewFragment(fragment2[0], fragment2[1]);
+      FragmentMetadata metadata0 = testDataset.createNewFragment(3);
+      FragmentMetadata metadata1 = testDataset.createNewFragment(5);
+      FragmentMetadata metadata2 = testDataset.createNewFragment(7);
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(metadata0, metadata1, metadata2));
       try (Dataset dataset = Dataset.commit(allocator, datasetPath, appendOp, Optional.of(1L))) {
-        validScanResult(dataset, fragment0[0], fragment0[1]);
-        validScanResult(dataset, fragment1[0], fragment1[1]);
-        validScanResult(dataset, fragment2[0], fragment2[1]);
+        List<DatasetFragment> frags = dataset.getFragments();
+        assertEquals(3, frags.size());
+        validScanResult(dataset, frags.get(0).getId(), 3);
+        validScanResult(dataset, frags.get(1).getId(), 5);
+        validScanResult(dataset, frags.get(2).getId(), 7);
       }
     }
   }
@@ -246,15 +245,14 @@ void testScanFragments() throws Exception {
     try (BufferAllocator allocator = new RootAllocator()) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      int[] fragment0 = new int[]{0, 3};
-      int[] fragment1 = new int[]{1, 5};
-      int[] fragment2 = new int[]{2, 7};
-      FragmentMetadata metadata0 = testDataset.createNewFragment(fragment0[0], fragment0[1]);
-      FragmentMetadata metadata1 = testDataset.createNewFragment(fragment1[0], fragment1[1]);
-      FragmentMetadata metadata2 = testDataset.createNewFragment(fragment2[0], fragment2[1]);
+      FragmentMetadata metadata0 = testDataset.createNewFragment(3);
+      FragmentMetadata metadata1 = testDataset.createNewFragment(5);
+      FragmentMetadata metadata2 = testDataset.createNewFragment(7);
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(metadata0, metadata1, metadata2));
       try (Dataset dataset = Dataset.commit(allocator, datasetPath, appendOp, Optional.of(1L))) {
-        try (Scanner scanner = dataset.newScan(new ScanOptions.Builder().batchSize(1024).fragmentIds(Arrays.asList(1, 2)).build())) {
+        List<DatasetFragment> frags = dataset.getFragments();
+        assertEquals(3, frags.size());
+        try (Scanner scanner = dataset.newScan(new ScanOptions.Builder().batchSize(1024).fragmentIds(Arrays.asList(frags.get(1).getId(), frags.get(2).getId())).build())) {
           try (ArrowReader reader = scanner.scanBatches()) {
             assertEquals(dataset.getSchema().getFields(), reader.getVectorSchemaRoot().getSchema().getFields());
             int rowcount = 0;
diff --git a/java/core/src/test/java/com/lancedb/lance/TestUtils.java b/java/core/src/test/java/com/lancedb/lance/TestUtils.java
index 461adc4767..259f8aac18 100644
--- a/java/core/src/test/java/com/lancedb/lance/TestUtils.java
+++ b/java/core/src/test/java/com/lancedb/lance/TestUtils.java
@@ -76,8 +76,16 @@ public Dataset createEmptyDataset() {
       return dataset;
     }
 
-    public FragmentMetadata createNewFragment(int fragmentId, int rowCount) {
-      FragmentMetadata fragmentMeta;
+    public FragmentMetadata createNewFragment(int rowCount) {
+      List<FragmentMetadata> fragmentMetas = createNewFragment(rowCount, Integer.MAX_VALUE);
+      assertEquals(1, fragmentMetas.size());
+      FragmentMetadata fragmentMeta = fragmentMetas.get(0);
+      assertEquals(rowCount, fragmentMeta.getPhysicalRows());
+      return fragmentMeta;
+    }
+
+    public List<FragmentMetadata> createNewFragment(int rowCount, int maxRowsPerFile) {
+      List<FragmentMetadata> fragmentMetas;
       try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
         root.allocateNew();
         IntVector idVector = (IntVector) root.getVector("id");
@@ -90,16 +98,14 @@ public FragmentMetadata createNewFragment(int fragmentId, int rowCount) {
         }
         root.setRowCount(rowCount);
 
-        fragmentMeta = Fragment.create(datasetPath,
-            allocator, root, Optional.of(fragmentId), new WriteParams.Builder().build());
-        assertEquals(fragmentId, fragmentMeta.getId());
-        assertEquals(rowCount, fragmentMeta.getPhysicalRows());
+        fragmentMetas = Fragment.create(datasetPath,
+            allocator, root, new WriteParams.Builder().withMaxRowsPerFile(maxRowsPerFile).build());
       }
-      return fragmentMeta;
+      return fragmentMetas;
     }
 
     public Dataset write(long version, int rowCount) {
-      FragmentMetadata metadata = createNewFragment(rowCount, rowCount);
+      FragmentMetadata metadata = createNewFragment(rowCount);
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(metadata));
       return Dataset.commit(allocator, datasetPath, appendOp, Optional.of(version));
     }
diff --git a/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java b/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java
index 564d47dd25..f2747eec68 100644
--- a/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java
+++ b/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java
@@ -102,7 +102,7 @@ private FragmentMetadata createFragment(int batchIndex) throws IOException {
       root.setRowCount(80);
 
       WriteParams fragmentWriteParams = new WriteParams.Builder().build();
-      return Fragment.create(datasetPath.toString(), allocator, root, Optional.of(batchIndex), fragmentWriteParams);
+      return Fragment.create(datasetPath.toString(), allocator, root, fragmentWriteParams).get(0);
     }
   }
 
@@ -127,8 +127,8 @@ public Dataset appendNewData() throws IOException {
       root.setRowCount(10);
 
       WriteParams writeParams = new WriteParams.Builder().build();
-      fragmentMetadata = Fragment.create(datasetPath.toString(), allocator, root, Optional.empty(),
-          writeParams);
+      fragmentMetadata = Fragment.create(datasetPath.toString(), allocator, root,
+          writeParams).get(0);
     }
     FragmentOperation.Append appendOp = new FragmentOperation.Append(Collections.singletonList(fragmentMetadata));
     return Dataset.commit(allocator, datasetPath.toString(), appendOp, Optional.of(2L));
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
index ff87744b6c..d674dfc4e6 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
@@ -33,9 +33,7 @@
 import java.util.stream.Collectors;
 
 public class LanceDatasetAdapter {
-  private static final BufferAllocator allocator = new RootAllocator(
-      RootAllocator.configBuilder().from(RootAllocator.defaultConfig())
-          .maxAllocation(64 * 1024 * 1024).build());
+  private static final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
 
   public static Optional<StructType> getSchema(LanceConfig config) {
     String uri = config.getDatasetUri();
@@ -88,12 +86,12 @@ public static LanceArrowWriter getArrowWriter(StructType sparkSchema, int batchS
         ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false), batchSize);
   }
 
-  public static FragmentMetadata createFragment(String datasetUri, ArrowReader reader,
+  public static List<FragmentMetadata> createFragment(String datasetUri, ArrowReader reader,
       WriteParams params) {
     try (ArrowArrayStream arrowStream = ArrowArrayStream.allocateNew(allocator)) {
       Data.exportArrayStream(allocator, reader, arrowStream);
       return Fragment.create(datasetUri, arrowStream,
-          java.util.Optional.empty(), params);
+          params);
     }
   }
 
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
index 706b6144d1..02dcf630c2 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
@@ -26,18 +26,18 @@
 import org.apache.spark.sql.types.StructType;
 
 import java.io.IOException;
-import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.FutureTask;
 
 public class LanceDataWriter implements DataWriter<InternalRow> {
   private LanceArrowWriter arrowWriter;
-  private FutureTask<FragmentMetadata> fragmentCreationTask;
+  private FutureTask<List<FragmentMetadata>> fragmentCreationTask;
   private Thread fragmentCreationThread;
 
   private LanceDataWriter(LanceArrowWriter arrowWriter,
-      FutureTask<FragmentMetadata> fragmentCreationTask, Thread fragmentCreationThread) {
+      FutureTask<List<FragmentMetadata>> fragmentCreationTask, Thread fragmentCreationThread) {
     // TODO support write to multiple fragments
     this.arrowWriter = arrowWriter;
     this.fragmentCreationThread = fragmentCreationThread;
@@ -53,8 +53,8 @@ public void write(InternalRow record) throws IOException {
   public WriterCommitMessage commit() throws IOException {
     arrowWriter.setFinished();
     try {
-      FragmentMetadata fragmentMetadata = fragmentCreationTask.get();
-      return new BatchAppend.TaskCommit(Arrays.asList(fragmentMetadata));
+      List<FragmentMetadata> fragmentMetadata = fragmentCreationTask.get();
+      return new BatchAppend.TaskCommit(fragmentMetadata);
     } catch (InterruptedException e) {
       Thread.currentThread().interrupt();
       throw new IOException("Interrupted while waiting for reader thread to finish", e);
@@ -93,9 +93,9 @@ protected WriterFactory(StructType schema, LanceConfig config) {
     public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
       LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, 1024);
       WriteParams params = SparkOptions.genWriteParamsFromConfig(config);
-      Callable<FragmentMetadata> fragmentCreator
+      Callable<List<FragmentMetadata>> fragmentCreator
           = () -> LanceDatasetAdapter.createFragment(config.getDatasetUri(), arrowWriter, params);
-      FutureTask<FragmentMetadata> fragmentCreationTask = new FutureTask<>(fragmentCreator);
+      FutureTask<List<FragmentMetadata>> fragmentCreationTask = new FutureTask<>(fragmentCreator);
       Thread fragmentCreationThread = new Thread(fragmentCreationTask);
       fragmentCreationThread.start();
 
diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java
index 78c5f9cb12..bc846bdb54 100644
--- a/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java
+++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java
@@ -32,7 +32,9 @@
 import org.junit.jupiter.api.TestInfo;
 import org.junit.jupiter.api.io.TempDir;
 
+import java.io.File;
 import java.nio.file.Path;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
@@ -52,6 +54,7 @@ static void setup() {
         .appName("spark-lance-connector-test")
         .master("local")
         .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog")
+        .config("spark.sql.catalog.lance.max_row_per_file", "1")
         .getOrCreate();
     StructType schema = new StructType(new StructField[]{
         DataTypes.createStructField("id", DataTypes.IntegerType, false),
@@ -144,6 +147,31 @@ public void overwrite(TestInfo testInfo) {
     validateData(datasetName, 1);
   }
 
+  @Test
+  public void writeMultiFiles(TestInfo testInfo) {
+    String datasetName = testInfo.getTestMethod().get().getName();
+    String filePath = LanceConfig.getDatasetUri(dbPath.toString(), datasetName);
+    testData.write().format(LanceDataSource.name)
+        .option(LanceConfig.CONFIG_DATASET_URI, filePath)
+        .save();
+
+    validateData(datasetName, 1);
+    File directory = new File(filePath + "/data");
+    assertEquals(2, directory.listFiles().length);
+  }
+
+  @Test
+  public void writeEmptyTaskFiles(TestInfo testInfo) {
+    String datasetName = testInfo.getTestMethod().get().getName();
+    String filePath = LanceConfig.getDatasetUri(dbPath.toString(), datasetName);
+    testData.repartition(4).write().format(LanceDataSource.name)
+        .option(LanceConfig.CONFIG_DATASET_URI, filePath)
+        .save();
+
+    File directory = new File(filePath + "/data");
+    assertEquals(2, directory.listFiles().length);
+  }
+
   private void validateData(String datasetName, int iteration) {
     Dataset<Row> data = spark.read().format("lance")
         .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName))
diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs
index b3cf1a10e4..d1d1d790ad 100644
--- a/rust/lance/src/dataset/fragment.rs
+++ b/rust/lance/src/dataset/fragment.rs
@@ -531,6 +531,21 @@ impl FileFragment {
         builder.write(source, Some(id as u64)).await
     }
 
+    /// Create a list of [`FileFragment`] from a [`StreamingWriteSource`].
+    pub async fn create_fragments(
+        dataset_uri: &str,
+        source: impl StreamingWriteSource,
+        params: Option<WriteParams>,
+    ) -> Result<Vec<Fragment>> {
+        let mut builder = FragmentCreateBuilder::new(dataset_uri);
+
+        if let Some(params) = params.as_ref() {
+            builder = builder.write_params(params);
+        }
+
+        builder.write_fragments(source).await
+    }
+
     pub async fn create_from_file(
         filename: &str,
         dataset: &Dataset,
diff --git a/rust/lance/src/dataset/fragment/write.rs b/rust/lance/src/dataset/fragment/write.rs
index 83e0fe8e21..1d9d5cb5a9 100644
--- a/rust/lance/src/dataset/fragment/write.rs
+++ b/rust/lance/src/dataset/fragment/write.rs
@@ -1,8 +1,6 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
-use std::borrow::Cow;
-
 use arrow_schema::Schema as ArrowSchema;
 use datafusion::execution::SendableRecordBatchStream;
 use futures::{StreamExt, TryStreamExt};
@@ -17,9 +15,12 @@ use lance_io::object_store::ObjectStore;
 use lance_table::format::{DataFile, Fragment};
 use lance_table::io::manifest::ManifestDescribing;
 use snafu::{location, Location};
+use std::borrow::Cow;
+use std::sync::Arc;
 use uuid::Uuid;
 
 use crate::dataset::builder::DatasetBuilder;
+use crate::dataset::write::do_write_fragments;
 use crate::dataset::{WriteMode, WriteParams, DATA_DIR};
 use crate::Result;
 
@@ -68,6 +69,15 @@ impl<'a> FragmentCreateBuilder<'a> {
         self.write_impl(stream, schema, id).await
     }
 
+    /// Write multiple fragments, split according to `max_rows_per_file`.
+    pub async fn write_fragments(
+        &self,
+        source: impl StreamingWriteSource,
+    ) -> Result<Vec<Fragment>> {
+        let (stream, schema) = self.get_stream_and_schema(Box::new(source)).await?;
+        self.write_fragments_v2_impl(stream, schema).await
+    }
+
     async fn write_v2_impl(
         &self,
         stream: SendableRecordBatchStream,
@@ -136,6 +146,31 @@ impl<'a> FragmentCreateBuilder<'a> {
         Ok(fragment)
     }
 
+    async fn write_fragments_v2_impl(
+        &self,
+        stream: SendableRecordBatchStream,
+        schema: Schema,
+    ) -> Result<Vec<Fragment>> {
+        let params = self.write_params.map(Cow::Borrowed).unwrap_or_default();
+
+        Self::validate_schema(&schema, stream.schema().as_ref())?;
+
+        let (object_store, base_path) = ObjectStore::from_uri_and_params(
+            params.object_store_registry.clone(),
+            self.dataset_uri,
+            &params.store_params.clone().unwrap_or_default(),
+        )
+        .await?;
+        do_write_fragments(
+            Arc::new(object_store),
+            &base_path,
+            &schema,
+            stream,
+            params.into_owned(),
+            LanceFileVersion::Stable,
+        )
+        .await
+    }
 
     async fn write_impl(
         &self,
@@ -353,4 +388,93 @@ mod tests {
         assert_eq!(fragment.files[0].fields, vec![3, 1]);
         assert_eq!(fragment.files[0].column_indices, vec![0, 1]);
     }
+
+    #[tokio::test]
+    async fn test_write_fragments_validation() {
+        // Writing with empty schema produces an error
+        let empty_schema = Arc::new(ArrowSchema::empty());
+        let empty_reader = Box::new(RecordBatchIterator::new(vec![], empty_schema));
+        let tmp_dir = tempfile::tempdir().unwrap();
+        let result = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap())
+            .write_fragments(empty_reader)
+            .await;
+        assert!(result.is_err());
+        assert!(
+            matches!(result.as_ref().unwrap_err(), Error::InvalidInput { source, .. }
+                if source.to_string().contains("Cannot write with an empty schema.")),
+            "{:?}",
+            &result
+        );
+
+        // Writing an empty reader produces an empty fragment list, not an error
+        let arrow_schema = test_data().schema();
+        let empty_reader = Box::new(RecordBatchIterator::new(vec![], arrow_schema.clone()));
+        let result = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap())
+            .write_fragments(empty_reader)
+            .await;
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().len(), 0);
+
+        // Writing with incorrect schema produces an error.
+        let wrong_schema = arrow_schema
+            .as_ref()
+            .try_with_column(ArrowField::new("c", DataType::Utf8, false))
+            .unwrap();
+        let wrong_schema = Schema::try_from(&wrong_schema).unwrap();
+        let result = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap())
+            .schema(&wrong_schema)
+            .write_fragments(test_data())
+            .await;
+        assert!(result.is_err());
+        assert!(
+            matches!(result.as_ref().unwrap_err(), Error::SchemaMismatch { difference, .. }
+                if difference.contains("fields did not match")),
+            "{:?}",
+            &result
+        );
+    }
+
+    #[tokio::test]
+    async fn test_write_fragments_default_schema() {
+        // Infers schema and uses 0 as default field id
+        let data = test_data();
+        let tmp_dir = tempfile::tempdir().unwrap();
+        let fragments = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap())
+            .write_fragments(data)
+            .await
+            .unwrap();
+
+        // If unspecified, the fragment id should be 0.
+        assert_eq!(fragments.len(), 1);
+        assert_eq!(fragments[0].deletion_file, None);
+        assert_eq!(fragments[0].files.len(), 1);
+        assert_eq!(fragments[0].files[0].fields, vec![0, 1]);
+    }
+
+    #[tokio::test]
+    async fn test_write_fragments_with_options() {
+        // Uses provided schema. Field ids are correct in fragment metadata.
+        let data = test_data();
+        let tmp_dir = tempfile::tempdir().unwrap();
+        let writer_params = WriteParams {
+            max_rows_per_file: 1,
+            ..Default::default()
+        };
+        let fragments = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap())
+            .write_params(&writer_params)
+            .write_fragments(data)
+            .await
+            .unwrap();
+
+        assert_eq!(fragments.len(), 3);
+        assert_eq!(fragments[0].deletion_file, None);
+        assert_eq!(fragments[0].files.len(), 1);
+        assert_eq!(fragments[0].files[0].column_indices, vec![0, 1]);
+        assert_eq!(fragments[1].deletion_file, None);
+        assert_eq!(fragments[1].files.len(), 1);
+        assert_eq!(fragments[1].files[0].column_indices, vec![0, 1]);
+        assert_eq!(fragments[2].deletion_file, None);
+        assert_eq!(fragments[2].files.len(), 1);
+        assert_eq!(fragments[2].files[0].column_indices, vec![0, 1]);
+    }
 }