diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs
index 66182b2d44..dacdd08798 100644
--- a/java/core/lance-jni/src/fragment.rs
+++ b/java/core/lance-jni/src/fragment.rs
@@ -29,7 +29,6 @@ use lance_datafusion::utils::StreamingWriteSource;
 use crate::error::{Error, Result};
 use crate::{
     blocking_dataset::{BlockingDataset, NATIVE_DATASET},
-    ffi::JNIEnvExt,
     traits::FromJString,
     utils::extract_write_params,
     RT,
@@ -77,7 +76,6 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local
     dataset_uri: JString,
     arrow_array_addr: jlong,
     arrow_schema_addr: jlong,
-    fragment_id: JObject, // Optional
     max_rows_per_file: JObject, // Optional
     max_rows_per_group: JObject, // Optional
     max_bytes_per_file: JObject, // Optional
@@ -91,7 +89,6 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local
         dataset_uri,
         arrow_array_addr,
         arrow_schema_addr,
-        fragment_id,
         max_rows_per_file,
         max_rows_per_group,
         max_bytes_per_file,
@@ -108,7 +105,6 @@ fn inner_create_with_ffi_array<'local>(
     dataset_uri: JString,
     arrow_array_addr: jlong,
     arrow_schema_addr: jlong,
-    fragment_id: JObject, // Optional
     max_rows_per_file: JObject, // Optional
     max_rows_per_group: JObject, // Optional
     max_bytes_per_file: JObject, // Optional
@@ -131,7 +127,6 @@ fn inner_create_with_ffi_array<'local>(
     create_fragment(
         env,
         dataset_uri,
-        fragment_id,
         max_rows_per_file,
         max_rows_per_group,
         max_bytes_per_file,
@@ -147,7 +142,6 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiStream<'a>(
     _obj: JObject,
     dataset_uri: JString,
     arrow_array_stream_addr: jlong,
-    fragment_id: JObject, // Optional
     max_rows_per_file: JObject, // Optional
     max_rows_per_group: JObject, // Optional
     max_bytes_per_file: JObject, // Optional
@@ -160,7 +154,6 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiStream<'a>(
         &mut env,
         dataset_uri,
         arrow_array_stream_addr,
-        fragment_id,
         max_rows_per_file,
         max_rows_per_group,
         max_bytes_per_file,
@@ -176,7 +169,6 @@ fn inner_create_with_ffi_stream<'local>(
     env: &mut JNIEnv<'local>,
     dataset_uri: JString,
     arrow_array_stream_addr: jlong,
-    fragment_id: JObject, // Optional
     max_rows_per_file: JObject, // Optional
     max_rows_per_group: JObject, // Optional
     max_bytes_per_file: JObject, // Optional
@@ -189,7 +181,6 @@ fn inner_create_with_ffi_stream<'local>(
     create_fragment(
         env,
         dataset_uri,
-        fragment_id,
         max_rows_per_file,
         max_rows_per_group,
         max_bytes_per_file,
@@ -203,7 +194,6 @@ fn create_fragment<'a>(
     env: &mut JNIEnv<'a>,
     dataset_uri: JString,
-    fragment_id: JObject, // Optional
     max_rows_per_file: JObject, // Optional
     max_rows_per_group: JObject, // Optional
     max_bytes_per_file: JObject, // Optional
@@ -213,8 +203,6 @@ fn create_fragment<'a>(
 ) -> Result<JString<'a>> {
     let path_str = dataset_uri.extract(env)?;
 
-    let fragment_id_opts = env.get_int_opt(&fragment_id)?;
-
     let write_params = extract_write_params(
         env,
         &max_rows_per_file,
@@ -223,9 +211,8 @@ fn create_fragment<'a>(
         &mode,
         &storage_options_obj,
     )?;
-    let fragment = RT.block_on(FileFragment::create(
+    let fragment = RT.block_on(FileFragment::create_fragments(
         &path_str,
-        fragment_id_opts.unwrap_or(0) as usize,
         source,
         Some(write_params),
     ))?;
diff --git a/java/core/lance-jni/src/utils.rs b/java/core/lance-jni/src/utils.rs
index 6b15d4d58b..5f780de6c5 100644
--- a/java/core/lance-jni/src/utils.rs
+++ b/java/core/lance-jni/src/utils.rs
@@ -56,8 +56,8 @@ pub fn extract_write_params(
     if let Some(mode_val) = env.get_string_opt(mode)? {
         write_params.mode = WriteMode::try_from(mode_val.as_str())?;
     }
-    // Java code always sets the data storage version to Legacy for now
-    write_params.data_storage_version = Some(LanceFileVersion::Legacy);
+    // Java code always sets the data storage version to stable for now
+    write_params.data_storage_version = Some(LanceFileVersion::Stable);
     let jmap = JMap::from_env(env, storage_options_obj)?;
     let storage_options: HashMap<String, String> = env.with_local_frame(16, |env| {
         let mut map = HashMap::new();
diff --git a/java/core/src/main/java/com/lancedb/lance/Fragment.java b/java/core/src/main/java/com/lancedb/lance/Fragment.java
index db994a6e4a..fed5a95695 100644
--- a/java/core/src/main/java/com/lancedb/lance/Fragment.java
+++ b/java/core/src/main/java/com/lancedb/lance/Fragment.java
@@ -14,6 +14,7 @@
 package com.lancedb.lance;
 
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import org.apache.arrow.c.ArrowArray;
@@ -36,24 +37,22 @@ public class Fragment {
    * @param datasetUri the dataset uri
    * @param allocator the buffer allocator
    * @param root the vector schema root
-   * @param fragmentId the fragment id
    * @param params the write params
    * @return the fragment metadata
    */
-  public static FragmentMetadata create(String datasetUri, BufferAllocator allocator,
-      VectorSchemaRoot root, Optional<Integer> fragmentId, WriteParams params) {
+  public static List<FragmentMetadata> create(String datasetUri, BufferAllocator allocator,
+      VectorSchemaRoot root, WriteParams params) {
     Preconditions.checkNotNull(datasetUri);
     Preconditions.checkNotNull(allocator);
     Preconditions.checkNotNull(root);
-    Preconditions.checkNotNull(fragmentId);
     Preconditions.checkNotNull(params);
     try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator);
        ArrowArray arrowArray = ArrowArray.allocateNew(allocator)) {
       Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchema);
-      return FragmentMetadata.fromJson(createWithFfiArray(datasetUri, arrowArray.memoryAddress(),
-          arrowSchema.memoryAddress(), fragmentId, params.getMaxRowsPerFile(),
-          params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(), params.getMode(),
-          params.getStorageOptions()));
+      return FragmentMetadata.fromJsonArray(createWithFfiArray(datasetUri,
+          arrowArray.memoryAddress(), arrowSchema.memoryAddress(),
+          params.getMaxRowsPerFile(), params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(),
+          params.getMode(), params.getStorageOptions()));
     }
   }
 
@@ -61,18 +60,16 @@ public static FragmentMetadata create(String datasetUri, BufferAllocator allocat
    * Create a fragment from the given arrow stream.
    * @param datasetUri the dataset uri
    * @param stream the arrow stream
-   * @param fragmentId the fragment id
    * @param params the write params
    * @return the fragment metadata
    */
-  public static FragmentMetadata create(String datasetUri, ArrowArrayStream stream,
-      Optional<Integer> fragmentId, WriteParams params) {
+  public static List<FragmentMetadata> create(String datasetUri, ArrowArrayStream stream,
+      WriteParams params) {
     Preconditions.checkNotNull(datasetUri);
     Preconditions.checkNotNull(stream);
-    Preconditions.checkNotNull(fragmentId);
     Preconditions.checkNotNull(params);
-    return FragmentMetadata.fromJson(createWithFfiStream(datasetUri,
-        stream.memoryAddress(), fragmentId,
+    return FragmentMetadata.fromJsonArray(createWithFfiStream(datasetUri,
+        stream.memoryAddress(),
         params.getMaxRowsPerFile(), params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(),
         params.getMode(), params.getStorageOptions()));
   }
@@ -83,7 +80,7 @@ public static FragmentMetadata create(String datasetUri, ArrowArrayStream stream
    * @return the json serialized fragment metadata
    */
   private static native String createWithFfiArray(String datasetUri,
-      long arrowArrayMemoryAddress, long arrowSchemaMemoryAddress, Optional<Integer> fragmentId,
+      long arrowArrayMemoryAddress, long arrowSchemaMemoryAddress,
       Optional<Integer> maxRowsPerFile, Optional<Integer> maxRowsPerGroup,
       Optional<Long> maxBytesPerFile, Optional<String> mode,
       Map<String, String> storageOptions);
@@ -93,7 +90,7 @@ private static native String createWithFfiArray(String datasetUri,
    * @return the json serialized fragment metadata
    */
   private static native String createWithFfiStream(String datasetUri, long arrowStreamMemoryAddress,
-      Optional<Integer> fragmentId, Optional<Integer> maxRowsPerFile,
+      Optional<Integer> maxRowsPerFile,
       Optional<Integer> maxRowsPerGroup, Optional<Long> maxBytesPerFile,
       Optional<String> mode, Map<String, String> storageOptions);
 }
diff --git a/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java b/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java
index c2b5d665a2..c7f0f277cb 100644
--- a/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java
+++ b/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java
@@ -15,7 +15,11 @@
 package com.lancedb.lance;
 
 import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.arrow.util.Preconditions;
+import org.json.JSONArray;
 import org.json.JSONObject;
 
 import org.apache.commons.lang3.builder.ToStringBuilder;
@@ -75,4 +79,27 @@ public static FragmentMetadata fromJson(String jsonMetadata) {
     return new FragmentMetadata(jsonMetadata, metadata.getInt(ID_KEY),
         metadata.getLong(PHYSICAL_ROWS_KEY));
   }
+
+  /**
+   * Converts a JSON array string into a list of FragmentMetadata objects.
+   *
+   * @param jsonMetadata A JSON array string containing fragment metadata.
+   * @return A list of FragmentMetadata objects.
+   */
+  public static List<FragmentMetadata> fromJsonArray(String jsonMetadata) {
+    Preconditions.checkNotNull(jsonMetadata);
+    JSONArray metadatas = new JSONArray(jsonMetadata);
+    List<FragmentMetadata> fragmentMetadataList = new ArrayList<>();
+    for (Object object : metadatas) {
+      JSONObject metadata = (JSONObject) object;
+      if (!metadata.has(ID_KEY) || !metadata.has(PHYSICAL_ROWS_KEY)) {
+        throw new IllegalArgumentException(
+            String.format("Fragment metadata must have %s and %s but is %s",
+                ID_KEY, PHYSICAL_ROWS_KEY, jsonMetadata));
+      }
+      fragmentMetadataList.add(new FragmentMetadata(metadata.toString(), metadata.getInt(ID_KEY),
+          metadata.getLong(PHYSICAL_ROWS_KEY)));
+    }
+    return fragmentMetadataList;
+  }
 }
diff --git a/java/core/src/test/java/com/lancedb/lance/FragmentTest.java b/java/core/src/test/java/com/lancedb/lance/FragmentTest.java
index a9fbe6c017..c7de20fc99 100644
--- a/java/core/src/test/java/com/lancedb/lance/FragmentTest.java
+++ b/java/core/src/test/java/com/lancedb/lance/FragmentTest.java
@@ -21,7 +21,7 @@
 import java.nio.file.Path;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.HashMap;
+import java.util.List;
 import java.util.Optional;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.types.pojo.Schema;
@@ -37,7 +37,7 @@ void testFragmentCreateFfiArray() {
     try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      testDataset.createNewFragment(123, 20);
+      testDataset.createNewFragment(20);
     }
   }
 
@@ -47,9 +47,8 @@ void testFragmentCreate() throws Exception {
     try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      int fragmentId = 312;
       int rowCount = 21;
-      FragmentMetadata fragmentMeta = testDataset.createNewFragment(fragmentId, rowCount);
+      FragmentMetadata fragmentMeta = testDataset.createNewFragment(rowCount);
 
       // Commit fragment
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(fragmentMeta));
@@ -58,8 +57,7 @@ void testFragmentCreate() throws Exception {
         assertEquals(2, dataset.latestVersion());
         assertEquals(rowCount, dataset.countRows());
         DatasetFragment fragment = dataset.getFragments().get(0);
-        assertEquals(fragmentId, fragment.getId());
-
+
         try (LanceScanner scanner = fragment.newScan()) {
           Schema schemaRes = scanner.schema();
           assertEquals(testDataset.getSchema(), schemaRes);
@@ -74,7 +72,7 @@ void commitWithoutVersion() {
     try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      FragmentMetadata meta = testDataset.createNewFragment(123, 20);
+      FragmentMetadata meta = testDataset.createNewFragment(20);
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(meta));
       assertThrows(IllegalArgumentException.class, () -> {
         Dataset.commit(allocator, datasetPath, appendOp, Optional.empty());
@@ -88,7 +86,7 @@ void commitOldVersion() {
     try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      FragmentMetadata meta = testDataset.createNewFragment(123, 20);
+      FragmentMetadata meta = testDataset.createNewFragment(20);
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(meta));
       assertThrows(IllegalArgumentException.class, () -> {
         Dataset.commit(allocator, datasetPath, appendOp, Optional.of(0L));
@@ -107,4 +105,26 @@ void appendWithoutFragment() {
       });
     }
   }
+
+  @Test
+  void testEmptyFragments() {
+    String datasetPath = tempDir.resolve("testEmptyFragments").toString();
+    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      testDataset.createEmptyDataset().close();
+      List<FragmentMetadata> fragments = testDataset.createNewFragment(0, 10);
+      assertEquals(0, fragments.size());
+    }
+  }
+
+  @Test
+  void testMultiFragments() {
+    String datasetPath = tempDir.resolve("testMultiFragments").toString();
+    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
+      testDataset.createEmptyDataset().close();
+      List<FragmentMetadata> fragments = testDataset.createNewFragment(20, 10);
+      assertEquals(2, fragments.size());
+    }
+  }
 }
diff --git a/java/core/src/test/java/com/lancedb/lance/ScannerTest.java b/java/core/src/test/java/com/lancedb/lance/ScannerTest.java
index fc46a95c52..11d55a087d 100644
--- a/java/core/src/test/java/com/lancedb/lance/ScannerTest.java
+++ b/java/core/src/test/java/com/lancedb/lance/ScannerTest.java
@@ -225,17 +225,16 @@ void testScanFragment() throws Exception {
     try (BufferAllocator allocator = new RootAllocator()) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      int[] fragment0 = new int[]{0, 3};
-      int[] fragment1 = new int[]{1, 5};
-      int[] fragment2 = new int[]{2, 7};
-      FragmentMetadata metadata0 = testDataset.createNewFragment(fragment0[0], fragment0[1]);
-      FragmentMetadata metadata1 = testDataset.createNewFragment(fragment1[0], fragment1[1]);
-      FragmentMetadata metadata2 = testDataset.createNewFragment(fragment2[0], fragment2[1]);
+      FragmentMetadata metadata0 = testDataset.createNewFragment(3);
+      FragmentMetadata metadata1 = testDataset.createNewFragment(5);
+      FragmentMetadata metadata2 = testDataset.createNewFragment(7);
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(metadata0, metadata1, metadata2));
       try (Dataset dataset = Dataset.commit(allocator, datasetPath, appendOp, Optional.of(1L))) {
-        validScanResult(dataset, fragment0[0], fragment0[1]);
-        validScanResult(dataset, fragment1[0], fragment1[1]);
-        validScanResult(dataset, fragment2[0], fragment2[1]);
+        List<DatasetFragment> frags = dataset.getFragments();
+        assertEquals(3, frags.size());
+        validScanResult(dataset, frags.get(0).getId(), 3);
+        validScanResult(dataset, frags.get(1).getId(), 5);
+        validScanResult(dataset, frags.get(2).getId(), 7);
       }
     }
   }
@@ -246,15 +245,14 @@ void testScanFragments() throws Exception {
     try (BufferAllocator allocator = new RootAllocator()) {
       TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath);
       testDataset.createEmptyDataset().close();
-      int[] fragment0 = new int[]{0, 3};
-      int[] fragment1 = new int[]{1, 5};
-      int[] fragment2 = new int[]{2, 7};
-      FragmentMetadata metadata0 = testDataset.createNewFragment(fragment0[0], fragment0[1]);
-      FragmentMetadata metadata1 = testDataset.createNewFragment(fragment1[0], fragment1[1]);
-      FragmentMetadata metadata2 = testDataset.createNewFragment(fragment2[0], fragment2[1]);
+      FragmentMetadata metadata0 = testDataset.createNewFragment(3);
+      FragmentMetadata metadata1 = testDataset.createNewFragment(5);
+      FragmentMetadata metadata2 = testDataset.createNewFragment(7);
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(metadata0, metadata1, metadata2));
       try (Dataset dataset = Dataset.commit(allocator, datasetPath, appendOp, Optional.of(1L))) {
-        try (Scanner scanner = dataset.newScan(new ScanOptions.Builder().batchSize(1024).fragmentIds(Arrays.asList(1, 2)).build())) {
+        List<DatasetFragment> frags = dataset.getFragments();
+        assertEquals(3, frags.size());
+        try (Scanner scanner = dataset.newScan(new ScanOptions.Builder().batchSize(1024).fragmentIds(Arrays.asList(frags.get(1).getId(), frags.get(2).getId())).build())) {
           try (ArrowReader reader = scanner.scanBatches()) {
             assertEquals(dataset.getSchema().getFields(), reader.getVectorSchemaRoot().getSchema().getFields());
             int rowcount = 0;
diff --git a/java/core/src/test/java/com/lancedb/lance/TestUtils.java b/java/core/src/test/java/com/lancedb/lance/TestUtils.java
index 461adc4767..259f8aac18 100644
--- a/java/core/src/test/java/com/lancedb/lance/TestUtils.java
+++ b/java/core/src/test/java/com/lancedb/lance/TestUtils.java
@@ -76,8 +76,16 @@ public Dataset createEmptyDataset() {
       return dataset;
     }
 
-    public FragmentMetadata createNewFragment(int fragmentId, int rowCount) {
-      FragmentMetadata fragmentMeta;
+    public FragmentMetadata createNewFragment(int rowCount) {
+      List<FragmentMetadata> fragmentMetas = createNewFragment(rowCount, Integer.MAX_VALUE);
+      assertEquals(1, fragmentMetas.size());
+      FragmentMetadata fragmentMeta = fragmentMetas.get(0);
+      assertEquals(rowCount, fragmentMeta.getPhysicalRows());
+      return fragmentMeta;
+    }
+
+    public List<FragmentMetadata> createNewFragment(int rowCount, int maxRowsPerFile) {
+      List<FragmentMetadata> fragmentMetas;
       try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
         root.allocateNew();
         IntVector idVector = (IntVector) root.getVector("id");
@@ -90,16 +98,14 @@ public FragmentMetadata createNewFragment(int fragmentId, int rowCount) {
         }
         root.setRowCount(rowCount);
 
-        fragmentMeta = Fragment.create(datasetPath,
-            allocator, root, Optional.of(fragmentId), new WriteParams.Builder().build());
-        assertEquals(fragmentId, fragmentMeta.getId());
-        assertEquals(rowCount, fragmentMeta.getPhysicalRows());
+        fragmentMetas = Fragment.create(datasetPath,
+            allocator, root, new WriteParams.Builder().withMaxRowsPerFile(maxRowsPerFile).build());
       }
-      return fragmentMeta;
+      return fragmentMetas;
     }
 
     public Dataset write(long version, int rowCount) {
-      FragmentMetadata metadata = createNewFragment(rowCount, rowCount);
+      FragmentMetadata metadata = createNewFragment(rowCount);
       FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(metadata));
       return Dataset.commit(allocator, datasetPath, appendOp, Optional.of(version));
     }
diff --git a/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java b/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java
index 564d47dd25..f2747eec68 100644
--- a/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java
+++ b/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java
@@ -102,7 +102,7 @@ private FragmentMetadata createFragment(int batchIndex) throws IOException {
       root.setRowCount(80);
 
       WriteParams fragmentWriteParams = new WriteParams.Builder().build();
-      return Fragment.create(datasetPath.toString(), allocator, root, Optional.of(batchIndex), fragmentWriteParams);
+      return Fragment.create(datasetPath.toString(), allocator, root, fragmentWriteParams).get(0);
     }
   }
 
@@ -127,8 +127,8 @@ public Dataset appendNewData() throws IOException {
       root.setRowCount(10);
 
       WriteParams writeParams = new WriteParams.Builder().build();
-      fragmentMetadata = Fragment.create(datasetPath.toString(), allocator, root, Optional.empty(),
-          writeParams);
+      fragmentMetadata = Fragment.create(datasetPath.toString(), allocator, root,
+          writeParams).get(0);
     }
     FragmentOperation.Append appendOp = new FragmentOperation.Append(Collections.singletonList(fragmentMetadata));
     return Dataset.commit(allocator, datasetPath.toString(), appendOp, Optional.of(2L));
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
index ff87744b6c..d674dfc4e6 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
@@ -33,9 +33,7 @@
 import java.util.stream.Collectors;
 
 public class LanceDatasetAdapter {
-  private static final BufferAllocator allocator = new RootAllocator(
-      RootAllocator.configBuilder().from(RootAllocator.defaultConfig())
-          .maxAllocation(64 * 1024 * 1024).build());
+  private static final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
 
   public static Optional<StructType> getSchema(LanceConfig config) {
     String uri = config.getDatasetUri();
@@ -88,12 +86,12 @@ public static LanceArrowWriter getArrowWriter(StructType sparkSchema, int batchS
         ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false), batchSize);
   }
 
-  public static FragmentMetadata createFragment(String datasetUri, ArrowReader reader,
+  public static List<FragmentMetadata> createFragment(String datasetUri, ArrowReader reader,
      WriteParams params) {
     try (ArrowArrayStream arrowStream = ArrowArrayStream.allocateNew(allocator)) {
       Data.exportArrayStream(allocator, reader, arrowStream);
       return Fragment.create(datasetUri, arrowStream,
-          java.util.Optional.empty(), params);
+          params);
     }
   }
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
index 706b6144d1..02dcf630c2 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
@@ -26,18 +26,18 @@
 import org.apache.spark.sql.types.StructType;
 
 import java.io.IOException;
-import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.FutureTask;
 
 public class LanceDataWriter implements DataWriter<InternalRow> {
   private LanceArrowWriter arrowWriter;
-  private FutureTask<FragmentMetadata> fragmentCreationTask;
+  private FutureTask<List<FragmentMetadata>> fragmentCreationTask;
   private Thread fragmentCreationThread;
 
   private LanceDataWriter(LanceArrowWriter arrowWriter,
-      FutureTask<FragmentMetadata> fragmentCreationTask, Thread fragmentCreationThread) {
+      FutureTask<List<FragmentMetadata>> fragmentCreationTask, Thread fragmentCreationThread) {
     // TODO support write to multiple fragments
     this.arrowWriter = arrowWriter;
     this.fragmentCreationThread = fragmentCreationThread;
@@ -53,8 +53,8 @@ public void write(InternalRow record) throws IOException {
   public WriterCommitMessage commit() throws IOException {
     arrowWriter.setFinished();
     try {
-      FragmentMetadata fragmentMetadata = fragmentCreationTask.get();
-      return new BatchAppend.TaskCommit(Arrays.asList(fragmentMetadata));
+      List<FragmentMetadata> fragmentMetadata = fragmentCreationTask.get();
+      return new BatchAppend.TaskCommit(fragmentMetadata);
     } catch (InterruptedException e) {
       Thread.currentThread().interrupt();
       throw new IOException("Interrupted while waiting for reader thread to finish", e);
@@ -93,9 +93,9 @@ protected WriterFactory(StructType schema, LanceConfig config) {
     public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
       LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, 1024);
       WriteParams params = SparkOptions.genWriteParamsFromConfig(config);
-      Callable<FragmentMetadata> fragmentCreator
+      Callable<List<FragmentMetadata>> fragmentCreator
           = () -> LanceDatasetAdapter.createFragment(config.getDatasetUri(), arrowWriter, params);
-      FutureTask<FragmentMetadata> fragmentCreationTask = new FutureTask<>(fragmentCreator);
+      FutureTask<List<FragmentMetadata>> fragmentCreationTask = new FutureTask<>(fragmentCreator);
       Thread fragmentCreationThread = new Thread(fragmentCreationTask);
       fragmentCreationThread.start();
diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java
index 78c5f9cb12..bc846bdb54 100644
--- a/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java
+++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java
@@ -32,7 +32,9 @@
 import org.junit.jupiter.api.TestInfo;
 import org.junit.jupiter.api.io.TempDir;
 
+import java.io.File;
 import java.nio.file.Path;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
@@ -52,6 +54,7 @@ static void setup() {
         .appName("spark-lance-connector-test")
         .master("local")
         .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog")
+        .config("spark.sql.catalog.lance.max_row_per_file", "1")
         .getOrCreate();
     StructType schema = new StructType(new StructField[]{
         DataTypes.createStructField("id", DataTypes.IntegerType, false),
@@ -144,6 +147,31 @@ public void overwrite(TestInfo testInfo) {
     validateData(datasetName, 1);
   }
 
+  @Test
+  public void writeMultiFiles(TestInfo testInfo) {
+    String datasetName = testInfo.getTestMethod().get().getName();
+    String filePath = LanceConfig.getDatasetUri(dbPath.toString(), datasetName);
+    testData.write().format(LanceDataSource.name)
+        .option(LanceConfig.CONFIG_DATASET_URI, filePath)
+        .save();
+
+    validateData(datasetName, 1);
+    File directory = new File(filePath + "/data");
+    assertEquals(2, directory.listFiles().length);
+  }
+
+  @Test
+  public void writeEmptyTaskFiles(TestInfo testInfo) {
+    String datasetName = testInfo.getTestMethod().get().getName();
+    String filePath = LanceConfig.getDatasetUri(dbPath.toString(), datasetName);
+    testData.repartition(4).write().format(LanceDataSource.name)
+        .option(LanceConfig.CONFIG_DATASET_URI, filePath)
+        .save();
+
+    File directory = new File(filePath + "/data");
+    assertEquals(2, directory.listFiles().length);
+  }
+
   private void validateData(String datasetName, int iteration) {
     Dataset<Row> data = spark.read().format("lance")
         .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName))
diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs
index b3cf1a10e4..d1d1d790ad 100644
--- a/rust/lance/src/dataset/fragment.rs
+++ b/rust/lance/src/dataset/fragment.rs
@@ -531,6 +531,21 @@ impl FileFragment {
         builder.write(source, Some(id as u64)).await
     }
 
+    /// Create a list of [`Fragment`]s from a [`StreamingWriteSource`].
+    pub async fn create_fragments(
+        dataset_uri: &str,
+        source: impl StreamingWriteSource,
+        params: Option<WriteParams>,
+    ) -> Result<Vec<Fragment>> {
+        let mut builder = FragmentCreateBuilder::new(dataset_uri);
+
+        if let Some(params) = params.as_ref() {
+            builder = builder.write_params(params);
+        }
+
+        builder.write_fragments(source).await
+    }
+
     pub async fn create_from_file(
         filename: &str,
         dataset: &Dataset,
diff --git a/rust/lance/src/dataset/fragment/write.rs b/rust/lance/src/dataset/fragment/write.rs
index 83e0fe8e21..1d9d5cb5a9 100644
--- a/rust/lance/src/dataset/fragment/write.rs
+++ b/rust/lance/src/dataset/fragment/write.rs
@@ -1,8 +1,6 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
-use std::borrow::Cow;
-
 use arrow_schema::Schema as ArrowSchema;
 use datafusion::execution::SendableRecordBatchStream;
 use futures::{StreamExt, TryStreamExt};
@@ -17,9 +15,12 @@ use lance_io::object_store::ObjectStore;
 use lance_table::format::{DataFile, Fragment};
 use lance_table::io::manifest::ManifestDescribing;
 use snafu::{location, Location};
+use std::borrow::Cow;
+use std::sync::Arc;
 use uuid::Uuid;
 
 use crate::dataset::builder::DatasetBuilder;
+use crate::dataset::write::do_write_fragments;
 use crate::dataset::{WriteMode, WriteParams, DATA_DIR};
 use crate::Result;
 
@@ -68,6 +69,15 @@ impl<'a> FragmentCreateBuilder<'a> {
         self.write_impl(stream, schema, id).await
     }
 
+    /// Write multiple fragments, splitting the input according to `max_rows_per_file`.
+    pub async fn write_fragments(
+        &self,
+        source: impl StreamingWriteSource,
+    ) -> Result<Vec<Fragment>> {
+        let (stream, schema) = self.get_stream_and_schema(Box::new(source)).await?;
+        self.write_fragments_v2_impl(stream, schema).await
+    }
+
     async fn write_v2_impl(
         &self,
         stream: SendableRecordBatchStream,
@@ -136,6 +146,31 @@ impl<'a> FragmentCreateBuilder<'a> {
         Ok(fragment)
     }
 
+    async fn write_fragments_v2_impl(
+        &self,
+        stream: SendableRecordBatchStream,
+        schema: Schema,
+    ) -> Result<Vec<Fragment>> {
+        let params = self.write_params.map(Cow::Borrowed).unwrap_or_default();
+
+        Self::validate_schema(&schema, stream.schema().as_ref())?;
+
+        let (object_store, base_path) = ObjectStore::from_uri_and_params(
+            params.object_store_registry.clone(),
+            self.dataset_uri,
+            &params.store_params.clone().unwrap_or_default(),
+        )
+        .await?;
+        do_write_fragments(
+            Arc::new(object_store),
+            &base_path,
+            &schema,
+            stream,
+            params.into_owned(),
+            LanceFileVersion::Stable,
+        )
+        .await
+    }
 
     async fn write_impl(
         &self,
@@ -353,4 +388,93 @@ mod tests {
         assert_eq!(fragment.files[0].fields, vec![3, 1]);
         assert_eq!(fragment.files[0].column_indices, vec![0, 1]);
     }
+
+    #[tokio::test]
+    async fn test_write_fragments_validation() {
+        // Writing with an empty schema produces an error
+        let empty_schema = Arc::new(ArrowSchema::empty());
+        let empty_reader = Box::new(RecordBatchIterator::new(vec![], empty_schema));
+        let tmp_dir = tempfile::tempdir().unwrap();
+        let result = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap())
+            .write_fragments(empty_reader)
+            .await;
+        assert!(result.is_err());
+        assert!(
+            matches!(result.as_ref().unwrap_err(), Error::InvalidInput { source, .. }
+                if source.to_string().contains("Cannot write with an empty schema.")),
+            "{:?}",
+            &result
+        );
+
+        // Writing an empty reader produces no fragments rather than an error
+        let arrow_schema = test_data().schema();
+        let empty_reader = Box::new(RecordBatchIterator::new(vec![], arrow_schema.clone()));
+        let result = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap())
+            .write_fragments(empty_reader)
+            .await;
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().len(), 0);
+
+        // Writing with incorrect schema produces an error.
+        let wrong_schema = arrow_schema
+            .as_ref()
+            .try_with_column(ArrowField::new("c", DataType::Utf8, false))
+            .unwrap();
+        let wrong_schema = Schema::try_from(&wrong_schema).unwrap();
+        let result = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap())
+            .schema(&wrong_schema)
+            .write_fragments(test_data())
+            .await;
+        assert!(result.is_err());
+        assert!(
+            matches!(result.as_ref().unwrap_err(), Error::SchemaMismatch { difference, .. }
+                if difference.contains("fields did not match")),
+            "{:?}",
+            &result
+        );
+    }
+
+    #[tokio::test]
+    async fn test_write_fragments_default_schema() {
+        // Infers the schema and uses 0 as the starting field id
+        let data = test_data();
+        let tmp_dir = tempfile::tempdir().unwrap();
+        let fragments = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap())
+            .write_fragments(data)
+            .await
+            .unwrap();
+
+        // The input fits in a single file, so exactly one fragment is written.
+        assert_eq!(fragments.len(), 1);
+        assert_eq!(fragments[0].deletion_file, None);
+        assert_eq!(fragments[0].files.len(), 1);
+        assert_eq!(fragments[0].files[0].fields, vec![0, 1]);
+    }
+
+    #[tokio::test]
+    async fn test_write_fragments_with_options() {
+        // Uses provided schema. Field ids are correct in fragment metadata.
+        let data = test_data();
+        let tmp_dir = tempfile::tempdir().unwrap();
+        let writer_params = WriteParams {
+            max_rows_per_file: 1,
+            ..Default::default()
+        };
+        let fragments = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap())
+            .write_params(&writer_params)
+            .write_fragments(data)
+            .await
+            .unwrap();
+
+        assert_eq!(fragments.len(), 3);
+        assert_eq!(fragments[0].deletion_file, None);
+        assert_eq!(fragments[0].files.len(), 1);
+        assert_eq!(fragments[0].files[0].column_indices, vec![0, 1]);
+        assert_eq!(fragments[1].deletion_file, None);
+        assert_eq!(fragments[1].files.len(), 1);
+        assert_eq!(fragments[1].files[0].column_indices, vec![0, 1]);
+        assert_eq!(fragments[2].deletion_file, None);
+        assert_eq!(fragments[2].files.len(), 1);
+        assert_eq!(fragments[2].files[0].column_indices, vec![0, 1]);
+    }
 }
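Taken together, these changes drop the caller-supplied fragment id and let the writer split the input into multiple fragments: Fragment.create now returns a List<FragmentMetadata>, one entry per file written under maxRowsPerFile, and fragment ids are assigned when the fragments are committed. What follows is a minimal usage sketch, not part of the diff; the dataset URI and single-column schema are illustrative, and it assumes an empty dataset with that schema already exists at version 1 (e.g. created the way the tests' createEmptyDataset() does).

import java.util.Collections;
import java.util.List;
import java.util.Optional;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;

import com.lancedb.lance.Dataset;
import com.lancedb.lance.Fragment;
import com.lancedb.lance.FragmentMetadata;
import com.lancedb.lance.FragmentOperation;
import com.lancedb.lance.WriteParams;

public class MultiFragmentWriteExample {
  public static void main(String[] args) {
    String datasetUri = "/tmp/example.lance"; // illustrative; any existing Lance dataset uri
    Schema schema = new Schema(Collections.singletonList(
        Field.nullable("id", new ArrowType.Int(32, true))));
    try (BufferAllocator allocator = new RootAllocator();
        VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
      root.allocateNew();
      IntVector idVector = (IntVector) root.getVector("id");
      for (int i = 0; i < 20; i++) {
        idVector.setSafe(i, i);
      }
      root.setRowCount(20);

      // 20 rows with maxRowsPerFile = 10 should come back as two fragments.
      List<FragmentMetadata> fragments = Fragment.create(datasetUri, allocator, root,
          new WriteParams.Builder().withMaxRowsPerFile(10).build());

      // Commit all returned fragments in one append; ids are assigned at commit,
      // which is why the tests above read ids back via dataset.getFragments().
      FragmentOperation.Append appendOp = new FragmentOperation.Append(fragments);
      try (Dataset dataset = Dataset.commit(allocator, datasetUri, appendOp, Optional.of(1L))) {
        System.out.println(dataset.countRows()); // expect 20
      }
    }
  }
}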