feat(java): support overwrite for spark connector #3313

Merged

merged 5 commits on Jan 1, 2025
Changes from 3 commits
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions java/core/lance-jni/Cargo.toml
@@ -19,6 +19,7 @@ lance-encoding = { workspace = true }
lance-linalg = { workspace = true }
lance-index = { workspace = true }
lance-io.workspace = true
lance-core.workspace = true
arrow = { workspace = true, features = ["ffi"] }
arrow-schema.workspace = true
datafusion.workspace = true
68 changes: 68 additions & 0 deletions java/core/lance-jni/src/blocking_dataset.rs
@@ -34,6 +34,7 @@ use lance::dataset::{ColumnAlteration, Dataset, ReadParams, WriteParams};
use lance::io::{ObjectStore, ObjectStoreParams};
use lance::table::format::Fragment;
use lance::table::format::Index;
use lance_core::datatypes::Schema as LanceSchema;
use lance_index::DatasetIndexExt;
use lance_index::{IndexParams, IndexType};
use lance_io::object_store::ObjectStoreRegistry;
@@ -393,6 +394,73 @@ pub fn inner_commit_append<'local>(
dataset.into_java(env)
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_Dataset_commitOverwrite<'local>(
mut env: JNIEnv<'local>,
_obj: JObject,
path: JString,
arrow_schema_addr: jlong,
read_version_obj: JObject, // Optional<Long>
fragments_obj: JObject, // List<String>; each String is a JSON-serialized Fragment
storage_options_obj: JObject, // Map<String, String>
) -> JObject<'local> {
ok_or_throw!(
env,
inner_commit_overwrite(
&mut env,
path,
arrow_schema_addr,
read_version_obj,
fragments_obj,
storage_options_obj
)
)
}

pub fn inner_commit_overwrite<'local>(
env: &mut JNIEnv<'local>,
path: JString,
arrow_schema_addr: jlong,
read_version_obj: JObject, // Optional<Long>
fragments_obj: JObject, // List<String>; each String is a JSON-serialized Fragment
storage_options_obj: JObject, // Map<String, String>
) -> Result<JObject<'local>> {
let json_fragments = env.get_strings(&fragments_obj)?;
let mut fragments: Vec<Fragment> = Vec::new();
for json_fragment in json_fragments {
let fragment = Fragment::from_json(&json_fragment)?;
fragments.push(fragment);
}
let c_schema_ptr = arrow_schema_addr as *mut FFI_ArrowSchema;
let c_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema_ptr) };
let arrow_schema = Schema::try_from(&c_schema)?;
let schema = LanceSchema::try_from(&arrow_schema)?;

let op = Operation::Overwrite {
fragments,
schema,
config_upsert_values: None,
};
let path_str = path.extract(env)?;
let read_version = env.get_u64_opt(&read_version_obj)?;
let jmap = JMap::from_env(env, &storage_options_obj)?;
let storage_options: HashMap<String, String> = env.with_local_frame(16, |env| {
let mut map = HashMap::new();
let mut iter = jmap.iter(env)?;
while let Some((key, value)) = iter.next(env)? {
let key_jstring = JString::from(key);
let value_jstring = JString::from(value);
let key_string: String = env.get_string(&key_jstring)?.into();
let value_string: String = env.get_string(&value_jstring)?.into();
map.insert(key_string, value_string);
}
Ok::<_, Error>(map)
})?;

let dataset = BlockingDataset::commit(&path_str, op, read_version, storage_options)?;
dataset.into_java(env)
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_Dataset_releaseNativeDataset(
mut env: JNIEnv,
7 changes: 7 additions & 0 deletions java/core/src/main/java/com/lancedb/lance/Dataset.java
@@ -246,6 +246,13 @@ public static native Dataset commitAppend(
List<String> fragmentsMetadata,
Map<String, String> storageOptions);

public static native Dataset commitOverwrite(
String path,
long arrowSchemaMemoryAddress,
Optional<Long> readVersion,
List<String> fragmentsMetadata,
Map<String, String> storageOptions);

/**
* Drop a Dataset.
*
35 changes: 35 additions & 0 deletions java/core/src/main/java/com/lancedb/lance/FragmentOperation.java
@@ -14,8 +14,11 @@

package com.lancedb.lance;

import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.types.pojo.Schema;

import java.util.List;
import java.util.Map;
@@ -61,4 +64,36 @@ public Dataset commit(
storageOptions);
}
}

/** Fragment overwrite operation. */
public static class Overwrite extends FragmentOperation {
private final List<FragmentMetadata> fragments;
private final Schema schema;

public Overwrite(List<FragmentMetadata> fragments, Schema schema) {
validateFragments(fragments);
this.fragments = fragments;
this.schema = schema;
}

@Override
public Dataset commit(
BufferAllocator allocator,
String path,
Optional<Long> readVersion,
Map<String, String> storageOptions) {
Preconditions.checkNotNull(allocator);
Preconditions.checkNotNull(path);
Preconditions.checkNotNull(readVersion);
try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator)) {
Data.exportSchema(allocator, schema, null, arrowSchema);
return Dataset.commitOverwrite(
path,
arrowSchema.memoryAddress(),
readVersion,
fragments.stream().map(FragmentMetadata::getJsonMetadata).collect(Collectors.toList()),
storageOptions);
}
}
}
}
47 changes: 47 additions & 0 deletions java/core/src/test/java/com/lancedb/lance/FragmentTest.java
@@ -24,6 +24,7 @@
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

@@ -119,6 +120,52 @@ void appendWithoutFragment() {
}
}

@Test
void testOverwriteCommit() throws Exception {
String datasetPath = tempDir.resolve("testOverwriteCommit").toString();
try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
TestUtils.SimpleTestDataset testDataset =
new TestUtils.SimpleTestDataset(allocator, datasetPath);
testDataset.createEmptyDataset().close();

// Commit fragment
int rowCount = 20;
FragmentMetadata fragmentMeta = testDataset.createNewFragment(rowCount);
FragmentOperation.Overwrite overwrite =
new FragmentOperation.Overwrite(
Collections.singletonList(fragmentMeta), testDataset.getSchema());
try (Dataset dataset = Dataset.commit(allocator, datasetPath, overwrite, Optional.of(1L))) {
assertEquals(2, dataset.version());
assertEquals(2, dataset.latestVersion());
assertEquals(rowCount, dataset.countRows());
DatasetFragment fragment = dataset.getFragments().get(0);

try (LanceScanner scanner = fragment.newScan()) {
Schema schemaRes = scanner.schema();
assertEquals(testDataset.getSchema(), schemaRes);
}
}

// Commit fragment again
rowCount = 40;
fragmentMeta = testDataset.createNewFragment(rowCount);
overwrite =
new FragmentOperation.Overwrite(
Collections.singletonList(fragmentMeta), testDataset.getSchema());
try (Dataset dataset = Dataset.commit(allocator, datasetPath, overwrite, Optional.of(2L))) {
assertEquals(3, dataset.version());
assertEquals(3, dataset.latestVersion());
assertEquals(rowCount, dataset.countRows());
DatasetFragment fragment = dataset.getFragments().get(0);

try (LanceScanner scanner = fragment.newScan()) {
Schema schemaRes = scanner.schema();
assertEquals(testDataset.getSchema(), schemaRes);
}
}
}
}

@Test
void testEmptyFragments() {
String datasetPath = tempDir.resolve("testEmptyFragments").toString();
@@ -34,7 +34,8 @@
/** Lance Spark Dataset. */
public class LanceDataset implements SupportsRead, SupportsWrite, SupportsMetadataColumns {
private static final Set<TableCapability> CAPABILITIES =
ImmutableSet.of(TableCapability.BATCH_READ, TableCapability.BATCH_WRITE);
ImmutableSet.of(
TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE);

public static final MetadataColumn[] METADATA_COLUMNS =
new MetadataColumn[] {
@@ -108,4 +108,8 @@ public static int getBatchSize(LanceConfig config) {
public static boolean enableTopNPushDown(LanceConfig config) {
return Boolean.parseBoolean(config.getOptions().getOrDefault(topN_push_down, "true"));
}

public static boolean overwrite(LanceConfig config) {
return config.getOptions().getOrDefault(write_mode, "append").equalsIgnoreCase("overwrite");
}
}
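For context, a minimal usage sketch of how the new overwrite path would be triggered from Spark. This is not part of the diff; it assumes the connector registers under the short name "lance", accepts the dataset URI as the save path, and that the option key checked by SparkOptions.overwrite(...) is literally "write_mode".

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;

public class LanceOverwriteExample {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().appName("lance-overwrite").master("local[*]").getOrCreate();
    Dataset<Row> df = spark.range(100).toDF("id");

    // Route 1: connector option, read by SparkOptions.overwrite(config); "append" is the default.
    df.write().format("lance").mode(SaveMode.Append)
        .option("write_mode", "overwrite").save("/tmp/example.lance");

    // Route 2: standard Spark overwrite mode, routed through SparkWriteBuilder.truncate().
    df.write().format("lance").mode(SaveMode.Overwrite).save("/tmp/example.lance");

    spark.stop();
  }
}

Route 2 relies on the TRUNCATE capability added to LanceDataset above: Spark calls SparkWriteBuilder.truncate(), which sets the overwrite flag that LanceBatchWrite.commit() checks.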
@@ -26,6 +26,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.LanceArrowUtils;

Expand Down Expand Up @@ -76,7 +77,6 @@ public static void appendFragments(LanceConfig config, List<FragmentMetadata> fr
String uri = config.getDatasetUri();
ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
try (Dataset datasetRead = Dataset.open(allocator, uri, options)) {

Dataset.commit(
allocator,
config.getDatasetUri(),
@@ -87,6 +87,23 @@
}
}

public static void overwriteFragments(
LanceConfig config, List<FragmentMetadata> fragments, StructType sparkSchema) {
Schema schema = LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false);
FragmentOperation.Overwrite overwrite = new FragmentOperation.Overwrite(fragments, schema);
String uri = config.getDatasetUri();
ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
try (Dataset datasetRead = Dataset.open(allocator, uri, options)) {
Dataset.commit(
allocator,
config.getDatasetUri(),
overwrite,
java.util.Optional.of(datasetRead.version()),
options.getStorageOptions())
.close();
}
}

public static LanceArrowWriter getArrowWriter(StructType sparkSchema, int batchSize) {
return new LanceArrowWriter(
allocator, LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false), batchSize);
Expand Up @@ -54,7 +54,11 @@ public static LanceFragmentScanner create(
LanceConfig config = inputPartition.getConfig();
ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
dataset = Dataset.open(allocator, config.getDatasetUri(), options);
fragment = dataset.getFragments().get(fragmentId);
fragment =
dataset.getFragments().stream()
.filter(f -> f.getId() == fragmentId)
.findAny()
.orElseThrow(() -> new RuntimeException("no fragment found for " + fragmentId));

Review comment (Contributor): This is an O(n) operation? Is it sensitive to performance here?

Reply (Collaborator, author): If we wanted an O(1) lookup for the fragment, we would have to build dataset.getFragments() into a hash table and store it in the LanceInputPartition. The LanceInputPartition is serialized by Spark, which would cost a lot of memory for a big Lance dataset, so I think the O(n) filter is a suitable way here (see the hash-map sketch after this file's diff).
ScanOptions.Builder scanOptions = new ScanOptions.Builder();
scanOptions.columns(getColumnNames(inputPartition.getSchema()));
if (inputPartition.getWhereCondition().isPresent()) {
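For reference, a sketch of the O(1) alternative discussed in the thread above. It assumes DatasetFragment lives in com.lancedb.lance and that getId() returns the fragment id as an int; the resulting map would have to be carried (and serialized) with each LanceInputPartition, which is the memory cost the author points to.

import com.lancedb.lance.Dataset;
import com.lancedb.lance.DatasetFragment;

import java.util.Map;
import java.util.stream.Collectors;

final class FragmentIndex {
  private FragmentIndex() {}

  // Build an id -> fragment map once so each lookup is O(1) instead of an O(n) stream filter.
  static Map<Integer, DatasetFragment> indexById(Dataset dataset) {
    return dataset.getFragments().stream()
        .collect(Collectors.toMap(f -> f.getId(), f -> f));
  }
}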
@@ -16,6 +16,7 @@

import com.lancedb.lance.FragmentMetadata;
import com.lancedb.lance.spark.LanceConfig;
import com.lancedb.lance.spark.SparkOptions;
import com.lancedb.lance.spark.internal.LanceDatasetAdapter;

import org.apache.spark.sql.connector.write.BatchWrite;
@@ -28,13 +29,15 @@
import java.util.List;
import java.util.stream.Collectors;

public class BatchAppend implements BatchWrite {
public class LanceBatchWrite implements BatchWrite {
private final StructType schema;
private final LanceConfig config;
private final boolean overwrite;

public BatchAppend(StructType schema, LanceConfig config) {
public LanceBatchWrite(StructType schema, LanceConfig config, boolean overwrite) {
this.schema = schema;
this.config = config;
this.overwrite = overwrite;
}

@Override
@@ -55,7 +58,11 @@ public void commit(WriterCommitMessage[] messages) {
.map(TaskCommit::getFragments)
.flatMap(List::stream)
.collect(Collectors.toList());
LanceDatasetAdapter.appendFragments(config, fragments);
if (overwrite || SparkOptions.overwrite(this.config)) {
LanceDatasetAdapter.overwriteFragments(config, fragments, schema);
} else {
LanceDatasetAdapter.appendFragments(config, fragments);
}
}

@Override
@@ -56,7 +56,7 @@ public WriterCommitMessage commit() throws IOException {
arrowWriter.setFinished();
try {
List<FragmentMetadata> fragmentMetadata = fragmentCreationTask.get();
return new BatchAppend.TaskCommit(fragmentMetadata);
return new LanceBatchWrite.TaskCommit(fragmentMetadata);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IOException("Interrupted while waiting for reader thread to finish", e);
@@ -17,6 +17,7 @@
import com.lancedb.lance.spark.LanceConfig;

import org.apache.spark.sql.connector.write.BatchWrite;
import org.apache.spark.sql.connector.write.SupportsTruncate;
import org.apache.spark.sql.connector.write.Write;
import org.apache.spark.sql.connector.write.WriteBuilder;
import org.apache.spark.sql.connector.write.streaming.StreamingWrite;
@@ -26,15 +27,17 @@
public class SparkWrite implements Write {
private final LanceConfig config;
private final StructType schema;
private final boolean overwrite;

SparkWrite(StructType schema, LanceConfig config) {
SparkWrite(StructType schema, LanceConfig config, boolean overwrite) {
this.schema = schema;
this.config = config;
this.overwrite = overwrite;
}

@Override
public BatchWrite toBatch() {
return new BatchAppend(schema, config);
return new LanceBatchWrite(schema, config, overwrite);
}

@Override
@@ -43,9 +46,10 @@ public StreamingWrite toStreaming() {
}

/** Task commit. */
public static class SparkWriteBuilder implements WriteBuilder {
public static class SparkWriteBuilder implements SupportsTruncate, WriteBuilder {
private final LanceConfig config;
private final StructType schema;
private boolean overwrite = false;

public SparkWriteBuilder(StructType schema, LanceConfig config) {
this.schema = schema;
@@ -54,7 +58,13 @@ public SparkWriteBuilder(StructType schema, LanceConfig config) {

@Override
public Write build() {
return new SparkWrite(schema, config);
return new SparkWrite(schema, config, overwrite);
}

@Override
public WriteBuilder truncate() {
this.overwrite = true;
return this;
}
}
}