From 8c46445dc077322a801179f7d2ecd6a1dc16ce77 Mon Sep 17 00:00:00 2001
From: Weston Pace
Date: Fri, 8 Sep 2023 11:10:37 -0700
Subject: [PATCH] feat: update the commit function to work with transactions (#1193)

Previously the commit function did not use the new transactions feature.
This meant that a commit running concurrently with another operation could
lead to bad behavior. The commit operation was also limited mostly to
overwrites. Now commits can apply any kind of transaction operation.

---------

Co-authored-by: Will Jones
---
 python/python/lance/__init__.py      |   9 +-
 python/python/lance/dataset.py       | 144 +++++++++++++++++++++++----
 python/python/lance/fragment.py      |   2 +-
 python/python/tests/test_dataset.py  | 123 ++++++++++++++++++++++-
 python/python/tests/test_fragment.py |   6 +-
 python/src/dataset.rs                | 143 +++++++++++++++++++++++------
 python/src/lib.rs                    |   3 +-
 rust/lance/src/dataset.rs            | 103 ++++++++++++++-----
 rust/lance/src/dataset/fragment.rs   |  30 +++---
 9 files changed, 468 insertions(+), 95 deletions(-)

diff --git a/python/python/lance/__init__.py b/python/python/lance/__init__.py
index 7d113bcb31..73448ba9c9 100644
--- a/python/python/lance/__init__.py
+++ b/python/python/lance/__init__.py
@@ -26,13 +26,20 @@
     pd = None
     ts_types = Union[datetime, str]
 
-from .dataset import LanceDataset, LanceScanner, __version__, write_dataset
+from .dataset import (
+    LanceDataset,
+    LanceOperation,
+    LanceScanner,
+    __version__,
+    write_dataset,
+)
 from .fragment import FragmentMetadata, LanceFragment
 from .schema import json_to_schema, schema_to_json
 from .util import sanitize_ts
 
 __all__ = [
     "LanceDataset",
+    "LanceOperation",
     "LanceScanner",
     "__version__",
     "write_dataset",
diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index d589253d6e..c6f6f6d76a 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -17,6 +17,8 @@
 import json
 import os
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from datetime import datetime, timedelta
 from functools import lru_cache
 from pathlib import Path
@@ -31,7 +33,7 @@
 from .commit import CommitLock
 from .fragment import FragmentMetadata, LanceFragment
 from .lance import __version__ as __version__
-from .lance import _Dataset, _Scanner, _write_dataset
+from .lance import _Dataset, _Operation, _Scanner, _write_dataset
 
 try:
     import pandas as pd
@@ -680,21 +682,37 @@ def create_index(
     @staticmethod
     def _commit(
         base_uri: Union[str, Path],
-        new_schema: pa.Schema,
-        fragments: Iterable[FragmentMetadata],
-        mode: str = "append",
+        operation: LanceOperation.BaseOperation,
+        read_version: Optional[int] = None,
+        commit_lock: Optional[CommitLock] = None,
     ) -> LanceDataset:
-        """Create a new version of dataset with collected fragments.
+        """Create a new version of dataset
+
+        This is an advanced method which allows users to describe a change
+        that has been made to the data files. This method is not needed when using
+        Lance to apply changes (e.g. when using :py:class:`LanceDataset` or
+        :py:func:`write_dataset`).
+
+        Its current purpose is to allow changes to be made in a distributed
+        environment where no single process is doing all of the work; for example,
+        a distributed bulk update or a distributed bulk modify operation.
 
-        This method allows users to commit a version of dataset in a distributed
-        environment.
+        Once all of the changes have been made, this method can be called to make
+        the changes visible by updating the dataset manifest.
Parameters ---------- - new_schema : pa.Schema - The schema for the new version of dataset. - fragments : list[FragmentMetadata] - The fragments to create new version of dataset. + base_uri: str or Path + The base uri of the dataset + operation: BaseOperation + The operation to apply to the dataset. This describes what changes + have been made. + read_version: int, optional + The version of the dataset that was used as the base for the changes. + This is not needed for overwrite or restore operations. + commit_lock : CommitLock, optional + A custom commit lock. Only needed if your object store does not support + atomic commits. See the user guide for more details. Returns ------- @@ -708,8 +726,21 @@ def _commit( """ if isinstance(base_uri, Path): base_uri = str(base_uri) - if not isinstance(new_schema, pa.Schema): - raise TypeError(f"schema must be pyarrow.Schema, got {type(new_schema)}") + + if commit_lock: + if not callable(commit_lock): + raise TypeError( + f"commit_lock must be a function, got {type(commit_lock)}" + ) + + _Dataset.commit(base_uri, operation._to_inner(), read_version, commit_lock) + return LanceDataset(base_uri) + + +# LanceOperation is a namespace for operations that can be applied to a dataset. +class LanceOperation: + @staticmethod + def _validate_fragments(fragments): if not isinstance(fragments, list): raise TypeError( f"fragments must be list[FragmentMetadata], got {type(fragments)}" @@ -720,10 +751,86 @@ def _commit( raise TypeError( f"fragments must be list[FragmentMetadata], got {type(fragments[0])}" ) - raw_fragments = [f._metadata for f in fragments] - # TODO: make fragments as a generator - _Dataset.commit(base_uri, new_schema, raw_fragments) - return LanceDataset(base_uri) + + class BaseOperation(ABC): + @abstractmethod + def _to_inner(self): + raise NotImplementedError() + + @dataclass + class Overwrite(BaseOperation): + new_schema: pa.Schema + fragments: Iterable[FragmentMetadata] + + def __post_init__(self): + if not isinstance(self.new_schema, pa.Schema): + raise TypeError( + f"schema must be pyarrow.Schema, got {type(self.new_schema)}" + ) + LanceOperation._validate_fragments(self.fragments) + + def _to_inner(self): + raw_fragments = [f._metadata for f in self.fragments] + return _Operation.overwrite(self.new_schema, raw_fragments) + + @dataclass + class Append(BaseOperation): + fragments: Iterable[FragmentMetadata] + + def __post_init__(self): + LanceOperation._validate_fragments(self.fragments) + + def _to_inner(self): + raw_fragments = [f._metadata for f in self.fragments] + return _Operation.append(raw_fragments) + + @dataclass + class Delete(BaseOperation): + updated_fragments: Iterable[FragmentMetadata] + deleted_fragment_ids: Iterable[int] + predicate: str + + def __post_init__(self): + LanceOperation._validate_fragments(self.updated_fragments) + + def _to_inner(self): + raw_updated_fragments = [f._metadata for f in self.updated_fragments] + return _Operation.delete( + raw_updated_fragments, self.deleted_fragment_ids, self.predicate + ) + + @dataclass + class Rewrite(BaseOperation): + old_fragments: Iterable[FragmentMetadata] + new_fragments: Iterable[FragmentMetadata] + + def __post_init__(self): + LanceOperation._validate_fragments(self.old_fragments) + LanceOperation._validate_fragments(self.new_fragments) + + def _to_inner(self): + raw_old_fragments = [f._metadata for f in self.old_fragments] + raw_new_fragments = [f._metadata for f in self.new_fragments] + return _Operation.rewrite(raw_old_fragments, raw_new_fragments) + + @dataclass + 
class Merge(BaseOperation): + fragments: Iterable[FragmentMetadata] + schema: pa.Schema + + def __post_init__(self): + LanceOperation._validate_fragments(self.fragments) + + def _to_inner(self): + raw_fragments = [f._metadata for f in self.fragments] + return _Operation.merge(raw_fragments, self.schema) + + @dataclass + class Restore(BaseOperation): + version: int + + def _to_inner(self): + return _Operation.restore(self.version) class ScannerBuilder: @@ -991,6 +1098,9 @@ def write_dataset( The max number of rows to write before starting a new file max_rows_per_group: int, default 1024 The max number of rows before starting a new group (in the same file) + commit_lock : CommitLock, optional + A custom commit lock. Only needed if your object store does not support + atomic commits. See the user guide for more details. """ reader = _coerce_reader(data_obj, schema) diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py index 328332f639..4de842ae6c 100644 --- a/python/python/lance/fragment.py +++ b/python/python/lance/fragment.py @@ -347,4 +347,4 @@ def metadata(self) -> FragmentMetadata: FragmentMetadata """ - return FragmentMetadata(self._fragment.metadata()) + return FragmentMetadata(self._fragment.metadata().json()) diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index aab6a77b98..d86dbff223 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -362,7 +362,8 @@ def adder(batch: pa.RecordBatch) -> pa.RecordBatch: fragment_metadata = fragment.add_columns(adder, columns=["a"]) schema = dataset.schema.append(pa.field("c", pa.int64())) - dataset = lance.LanceDataset._commit(base_dir, schema, [fragment_metadata]) + operation = lance.LanceOperation.Overwrite(schema, [fragment_metadata]) + dataset = lance.LanceDataset._commit(base_dir, operation) assert dataset.schema == schema tbl = dataset.to_table() @@ -371,16 +372,126 @@ def adder(batch: pa.RecordBatch) -> pa.RecordBatch: ) -def test_create_from_fragments(tmp_path: Path): +def test_create_from_commit(tmp_path: Path): table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) base_dir = tmp_path / "test" fragment = lance.fragment.LanceFragment.create(base_dir, table) - dataset = lance.LanceDataset._commit(base_dir, table.schema, [fragment]) + operation = lance.LanceOperation.Overwrite(table.schema, [fragment]) + dataset = lance.LanceDataset._commit(base_dir, operation) tbl = dataset.to_table() assert tbl == table +def test_append_with_commit(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + base_dir = tmp_path / "test" + + lance.write_dataset(table, base_dir) + + fragment = lance.fragment.LanceFragment.create(base_dir, table) + append = lance.LanceOperation.Append([fragment]) + + with pytest.raises(OSError): + # Must specify read version + dataset = lance.LanceDataset._commit(base_dir, append) + + dataset = lance.LanceDataset._commit(base_dir, append, read_version=1) + + tbl = dataset.to_table() + + expected = pa.Table.from_pydict( + { + "a": list(range(100)) + list(range(100)), + "b": list(range(100)) + list(range(100)), + } + ) + assert tbl == expected + + +def test_delete_with_commit(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + base_dir = tmp_path / "test" + + lance.write_dataset(table, base_dir) + lance.write_dataset(table, base_dir, mode="append") + + half_table = pa.Table.from_pydict({"a": range(50), "b": range(50)}) + + fragments = 
lance.dataset(base_dir).get_fragments() + + updated_fragment = fragments[0].delete("a >= 50") + delete = lance.LanceOperation.Delete( + [updated_fragment], [fragments[1].fragment_id], "hello" + ) + + dataset = lance.LanceDataset._commit(base_dir, delete, read_version=2) + + tbl = dataset.to_table() + assert tbl == half_table + + +def test_rewrite_with_commit(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + base_dir = tmp_path / "test" + + lance.write_dataset(table, base_dir) + lance.write_dataset(table, base_dir, mode="append") + + combined = pa.Table.from_pydict( + { + "a": list(range(100)) + list(range(100)), + "b": list(range(100)) + list(range(100)), + } + ) + + to_be_rewrote = [lf.metadata for lf in lance.dataset(base_dir).get_fragments()] + + fragment = lance.fragment.LanceFragment.create(base_dir, combined) + rewrite = lance.LanceOperation.Rewrite(to_be_rewrote, [fragment]) + + dataset = lance.LanceDataset._commit(base_dir, rewrite, read_version=1) + + tbl = dataset.to_table() + assert tbl == combined + + assert len(lance.dataset(base_dir).get_fragments()) == 1 + + +def test_restore_with_commit(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + base_dir = tmp_path / "test" + + lance.write_dataset(table, base_dir) + lance.write_dataset(table, base_dir, mode="append") + + restore = lance.LanceOperation.Restore(1) + dataset = lance.LanceDataset._commit(base_dir, restore) + + tbl = dataset.to_table() + assert tbl == table + + +def test_merge_with_commit(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + base_dir = tmp_path / "test" + + lance.write_dataset(table, base_dir) + + fragment = lance.dataset(base_dir).get_fragments()[0] + merged = fragment.add_columns( + lambda _: pa.RecordBatch.from_pydict({"c": range(100)}) + ) + + expected = pa.Table.from_pydict({"a": range(100), "b": range(100), "c": range(100)}) + + merge = lance.LanceOperation.Merge([merged], expected.schema) + dataset = lance.LanceDataset._commit(base_dir, merge, read_version=1) + + tbl = dataset.to_table() + assert tbl == expected + + def test_data_files(tmp_path: Path): table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) base_dir = tmp_path / "test" @@ -411,7 +522,8 @@ def test_deletion_file(tmp_path: Path): assert new_fragment.deletion_file() is not None assert re.match("_deletions/0-1-[0-9]{1,32}.arrow", new_fragment.deletion_file()) print(type(new_fragment), new_fragment) - dataset = lance.LanceDataset._commit(base_dir, table.schema, [new_fragment]) + operation = lance.LanceOperation.Overwrite(table.schema, [new_fragment]) + dataset = lance.LanceDataset._commit(base_dir, operation) assert dataset.count_rows() == 90 @@ -429,7 +541,8 @@ def test_commit_fragments_via_scanner(tmp_path: Path): unpickled = pickle.loads(pickled) assert fragment_metadata == unpickled - dataset = lance.LanceDataset._commit(base_dir, table.schema, [fragment_metadata]) + operation = lance.LanceOperation.Overwrite(table.schema, [fragment_metadata]) + dataset = lance.LanceDataset._commit(base_dir, operation) assert dataset.schema == table.schema tbl = dataset.to_table() diff --git a/python/python/tests/test_fragment.py b/python/python/tests/test_fragment.py index 8a205773e4..9013070a02 100644 --- a/python/python/tests/test_fragment.py +++ b/python/python/tests/test_fragment.py @@ -20,7 +20,7 @@ import pandas as pd import pyarrow as pa import pytest -from lance import FragmentMetadata, LanceDataset, LanceFragment +from lance 
import FragmentMetadata, LanceDataset, LanceFragment, LanceOperation
 from lance.progress import FileSystemFragmentWriteProgress, FragmentWriteProgress
 
@@ -50,7 +50,9 @@ def test_write_fragment_two_phases(tmp_path: Path):
     fragments = [FragmentMetadata.from_json(j) for j in json_array]
 
     schema = pa.schema([pa.field("a", pa.int64())])
-    dataset = LanceDataset._commit(tmp_path, schema, fragments)
+
+    operation = LanceOperation.Overwrite(schema, fragments)
+    dataset = LanceDataset._commit(tmp_path, operation)
 
     df = dataset.to_table().to_pandas()
     pd.testing.assert_frame_equal(
diff --git a/python/src/dataset.rs b/python/src/dataset.rs
index cabd954e18..de5a386f51 100644
--- a/python/src/dataset.rs
+++ b/python/src/dataset.rs
@@ -37,7 +37,8 @@ use crate::fragment::{FileFragment, FragmentMetadata};
 use crate::Scanner;
 use crate::RT;
 use lance::dataset::{
-    scanner::Scanner as LanceScanner, Dataset as LanceDataset, Version, WriteMode, WriteParams,
+    scanner::Scanner as LanceScanner, transaction::Operation as LanceOperation,
+    Dataset as LanceDataset, Version, WriteMode, WriteParams,
 };
 use lance::index::{
     vector::diskann::DiskANNParams,
@@ -53,6 +54,94 @@ const DEFAULT_NPROBS: usize = 1;
 const DEFAULT_INDEX_CACHE_SIZE: usize = 256;
 const DEFAULT_METADATA_CACHE_SIZE: usize = 256;
 
+#[pyclass(name = "_Operation", module = "_lib")]
+#[derive(Clone)]
+pub struct Operation(LanceOperation);
+
+fn into_fragments(fragments: Vec<FragmentMetadata>) -> Vec<Fragment> {
+    fragments
+        .into_iter()
+        .map(|f| f.inner)
+        .collect::<Vec<_>>()
+}
+
+fn convert_schema(arrow_schema: &ArrowSchema) -> PyResult<Schema> {
+    Schema::try_from(arrow_schema).map_err(|e| {
+        PyValueError::new_err(format!(
+            "Failed to convert Arrow schema to Lance schema: {}",
+            e
+        ))
+    })
+}
+
+#[pymethods]
+impl Operation {
+    fn __repr__(&self) -> String {
+        format!("{:?}", self.0)
+    }
+
+    #[staticmethod]
+    fn overwrite(
+        schema: PyArrowType<ArrowSchema>,
+        fragments: Vec<FragmentMetadata>,
+    ) -> PyResult<Self> {
+        let schema = convert_schema(&schema.0)?;
+        let fragments = into_fragments(fragments);
+        let op = LanceOperation::Overwrite { fragments, schema };
+        Ok(Self(op))
+    }
+
+    #[staticmethod]
+    fn append(fragments: Vec<FragmentMetadata>) -> PyResult<Self> {
+        let fragments = into_fragments(fragments);
+        let op = LanceOperation::Append { fragments };
+        Ok(Self(op))
+    }
+
+    #[staticmethod]
+    fn delete(
+        updated_fragments: Vec<FragmentMetadata>,
+        deleted_fragment_ids: Vec<u64>,
+        predicate: String,
+    ) -> PyResult<Self> {
+        let updated_fragments = into_fragments(updated_fragments);
+        let op = LanceOperation::Delete {
+            updated_fragments,
+            deleted_fragment_ids,
+            predicate,
+        };
+        Ok(Self(op))
+    }
+
+    #[staticmethod]
+    fn rewrite(
+        old_fragments: Vec<FragmentMetadata>,
+        new_fragments: Vec<FragmentMetadata>,
+    ) -> PyResult<Self> {
+        let old_fragments = into_fragments(old_fragments);
+        let new_fragments = into_fragments(new_fragments);
+        let op = LanceOperation::Rewrite {
+            old_fragments,
+            new_fragments,
+        };
+        Ok(Self(op))
+    }
+
+    #[staticmethod]
+    fn merge(fragments: Vec<FragmentMetadata>, schema: PyArrowType<ArrowSchema>) -> PyResult<Self> {
+        let schema = convert_schema(&schema.0)?;
+        let fragments = into_fragments(fragments);
+        let op = LanceOperation::Merge { fragments, schema };
+        Ok(Self(op))
+    }
+
+    #[staticmethod]
+    fn restore(version: u64) -> PyResult<Self> {
+        let op = LanceOperation::Restore { version };
+        Ok(Self(op))
+    }
+}
+
 /// Lance Dataset that will be wrapped by another class in Python
 #[pyclass(name = "_Dataset", module = "_lib")]
 #[derive(Clone)]
@@ -499,27 +588,22 @@ impl Dataset {
     #[staticmethod]
     fn commit(
         dataset_uri: &str,
-        schema: &PyAny,
-        fragments: Vec<&PyAny>,
-        mode: Option<&str>,
+        operation: Operation,
+        read_version: Option<u64>,
+        commit_lock: Option<&PyAny>,
     ) -> PyResult<Self> {
-        let arrow_schema = ArrowSchema::from_pyarrow(schema)?;
-        let py = schema.py();
-        let schema = Schema::try_from(&arrow_schema).map_err(|e| {
-            PyValueError::new_err(format!(
-                "Failed to convert Arrow schema to Lance schema: {}",
-                e
-            ))
-        })?;
-        let fragment_metadata = fragments
-            .iter()
-            .map(|f| f.extract::<FragmentMetadata>().map(|fm| fm.inner))
-            .collect::<PyResult<Vec<_>>>()?;
-        let mode = parse_write_mode(mode.unwrap_or("create"))?;
+        let store_params = if let Some(commit_handler) = commit_lock {
+            let py_commit_lock = PyCommitLock::new(commit_handler.to_object(commit_handler.py()));
+            let mut object_store_params = ObjectStoreParams::default();
+            object_store_params.set_commit_lock(Arc::new(py_commit_lock));
+            Some(object_store_params)
+        } else {
+            None
+        };
         let ds = RT
             .block_on(
-                Some(py),
-                LanceDataset::commit(dataset_uri, &schema, &fragment_metadata, mode),
+                commit_lock.map(|cl| cl.py()),
+                LanceDataset::commit(dataset_uri, operation.0, read_version, store_params),
             )
             .map_err(|e| PyIOError::new_err(e.to_string()))?;
         Ok(Self {
@@ -565,6 +649,19 @@ fn parse_write_mode(mode: &str) -> PyResult<WriteMode> {
     }
 }
 
+pub(crate) fn get_object_store_params(options: &PyDict) -> Option<ObjectStoreParams> {
+    if options.is_none() {
+        None
+    } else if let Some(commit_handler) = options.get_item("commit_handler") {
+        let py_commit_lock = PyCommitLock::new(commit_handler.to_object(options.py()));
+        let mut object_store_params = ObjectStoreParams::default();
+        object_store_params.set_commit_lock(Arc::new(py_commit_lock));
+        Some(object_store_params)
+    } else {
+        None
+    }
+}
+
 pub(crate) fn get_write_params(options: &PyDict) -> PyResult<Option<WriteParams>> {
     let params = if options.is_none() {
         None
@@ -579,13 +676,7 @@ pub(crate) fn get_write_params(options: &PyDict) -> PyResult<Option<WriteParams>> {
         if let Some(maybe_nrows) = options.get_item("max_rows_per_group") {
             p.max_rows_per_group = usize::extract(maybe_nrows)?;
         }
-
-        if let Some(commit_handler) = options.get_item("commit_handler") {
-            let py_commit_lock = PyCommitLock::new(commit_handler.to_object(options.py()));
-            let mut object_store_params = ObjectStoreParams::default();
-            object_store_params.set_commit_lock(Arc::new(py_commit_lock));
-            p.store_params = Some(object_store_params);
-        }
+        p.store_params = get_object_store_params(options);
 
         Some(p)
     };
diff --git a/python/src/lib.rs b/python/src/lib.rs
index e226d06671..abe17e77ed 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -48,7 +48,7 @@ pub(crate) mod updater;
 pub use crate::arrow::{bfloat16_array, BFloat16};
 use crate::fragment::cleanup_partial_writes;
 pub use dataset::write_dataset;
-pub use dataset::Dataset;
+pub use dataset::{Dataset, Operation};
 pub use fragment::FragmentMetadata;
 use fragment::{DataFile, FileFragment};
 pub use reader::LanceReader;
@@ -70,6 +70,7 @@ fn lance(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::<Operation>()?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs
index c605e564b2..6db901a82c 100644
--- a/rust/lance/src/dataset.rs
+++ b/rust/lance/src/dataset.rs
@@ -559,45 +559,96 @@ impl Dataset {
         Ok(())
     }
 
-    /// Create a new version of [`Dataset`] from a collection of fragments.
+    /// Commit changes to the dataset
+    ///
+    /// This operation is not needed if you are using append/write/delete to manipulate the dataset.
+    /// It is used to commit changes to the dataset that are made externally.  For example, a bulk
+    /// import tool may import large amounts of new data and write the appropriate lance files
+    /// directly instead of using the write function.
+    ///
+    /// This method can be used to commit this change to the dataset's manifest.  This method will
+    /// not verify that the provided fragments exist and are correct; that is the caller's responsibility.
+    ///
+    /// If this commit is a change to an existing dataset then it will often need to be based on an
+    /// existing version of the dataset.  For example, if this change is a `delete` operation then
+    /// the caller will have read in the existing data (at some version) to determine which fragments
+    /// need to be deleted.  The base version that the caller used should be supplied as the `read_version`
+    /// parameter.  Some operations (e.g. Overwrite) do not depend on a previous version and `read_version`
+    /// can be None.  An error will be returned if the `read_version` is needed for an operation and
+    /// it is not specified.
+    ///
+    /// All operations except Overwrite will fail if the dataset does not already exist.
+    ///
+    /// # Arguments
+    ///
+    /// * `base_uri` - The base URI of the dataset
+    /// * `operation` - A description of the change to commit
+    /// * `read_version` - The version of the dataset that this change is based on
+    /// * `store_params` - Parameters controlling object store access to the manifest
     pub async fn commit(
         base_uri: &str,
-        schema: &Schema,
-        fragments: &[Fragment],
-        mode: WriteMode,
+        operation: Operation,
+        read_version: Option<u64>,
+        store_params: Option<ObjectStoreParams>,
     ) -> Result<Self> {
-        let (object_store, base) = ObjectStore::from_uri(base_uri).await?;
+        let read_version = read_version.map_or_else(
+            || match operation {
+                Operation::Overwrite { .. } | Operation::Restore { .. } => Ok(0),
+                _ => Err(Error::invalid_input(
+                    "read_version must be specified for this operation",
+                )),
+            },
+            Ok,
+        )?;
+
+        let (object_store, base) =
+            ObjectStore::from_uri_and_params(base_uri, &store_params.clone().unwrap_or_default())
+                .await?;
+
+        // Test if the dataset exists
         let latest_manifest = object_store
             .commit_handler
             .resolve_latest_version(&base, &object_store)
             .await?;
-        let mut indices = vec![];
-        let mut manifest = if object_store.exists(&latest_manifest).await? {
-            let dataset = Self::open(base_uri).await?;
+        let flag_dataset_exists = object_store.exists(&latest_manifest).await?;
 
-            if matches!(mode, WriteMode::Append) {
-                // Append mode: inherit indices from previous version.
-                indices = dataset.load_indices().await?;
-            }
+        if !flag_dataset_exists && !matches!(operation, Operation::Overwrite { .. }) {
+            return Err(Error::DatasetNotFound {
+                path: base.to_string(),
+                source: "The dataset must already exist unless the operation is Overwrite".into(),
+            });
+        }
 
-            let dataset_schema = dataset.schema();
-            let added_on_schema = schema.exclude(dataset_schema)?;
-            let schema = dataset_schema.merge(&added_on_schema)?;
+        let dataset = if flag_dataset_exists {
+            Some(
+                Self::open_with_params(
+                    base_uri,
+                    &ReadParams {
+                        store_options: store_params.clone(),
+                        ..Default::default()
+                    },
+                )
+                .await?,
+            )
+        } else {
+            None
+        };
+
+        let transaction = Transaction::new(read_version, operation, None);
 
-            Manifest::new_from_previous(&dataset.manifest, &schema, Arc::new(fragments.to_vec()))
+        let manifest = if let Some(dataset) = &dataset {
+            commit_transaction(
+                dataset,
+                &object_store,
+                &transaction,
+                &Default::default(),
+                &Default::default(),
+            )
+            .await?
} else { - Manifest::new(schema, Arc::new(fragments.to_vec())) + commit_new_dataset(&object_store, &base, &transaction, &Default::default()).await? }; - // Preserve indices. - write_manifest_file( - &object_store, - &base, - &mut manifest, - Some(indices), - &Default::default(), - ) - .await?; Ok(Self { object_store: Arc::new(object_store), base, diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index c190ea449f..51786426de 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -639,6 +639,7 @@ mod tests { use super::*; use crate::dataset::progress::NoopFragmentWriteProgress; + use crate::dataset::transaction::Operation; use crate::dataset::{WriteParams, ROW_ID}; async fn create_dataset(test_uri: &str) -> Dataset { @@ -845,14 +846,12 @@ mod tests { fragments.push(f) } - let new_dataset = Dataset::commit( - test_uri, - schema, - &fragments, - crate::dataset::WriteMode::Create, - ) - .await - .unwrap(); + let op = Operation::Overwrite { + schema: schema.clone(), + fragments, + }; + + let new_dataset = Dataset::commit(test_uri, op, None, None).await.unwrap(); assert_eq!(new_dataset.count_rows().await.unwrap(), dataset_rows); } @@ -918,14 +917,13 @@ mod tests { // Scan again let full_schema = dataset.schema().merge(new_schema.as_ref()).unwrap(); let before_version = dataset.version().version; - let dataset = Dataset::commit( - test_uri, - &full_schema, - &[new_fragment], - crate::dataset::WriteMode::Create, - ) - .await - .unwrap(); + + let op = Operation::Overwrite { + fragments: vec![new_fragment], + schema: full_schema.clone(), + }; + + let dataset = Dataset::commit(test_uri, op, None, None).await.unwrap(); // We only kept the first fragment of 40 rows assert_eq!(
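
Usage sketch (not part of the patch): the snippet below mirrors the new test_append_with_commit test and shows the distributed-writer workflow this change enables, where workers write fragments independently and a single coordinator commits one transaction. The dataset path and table contents are illustrative only; everything else (LanceFragment.create, LanceOperation.Append, LanceDataset._commit, read_version) comes from the patch above.

    import lance
    import pyarrow as pa

    uri = "/tmp/demo_dataset"  # hypothetical location
    lance.write_dataset(pa.Table.from_pydict({"a": range(100)}), uri)

    # Each worker writes its own data files; this only creates fragments,
    # it does not touch the dataset manifest.
    new_rows = pa.Table.from_pydict({"a": range(100, 200)})
    fragment = lance.fragment.LanceFragment.create(uri, new_rows)

    # A coordinator gathers the fragment metadata from all workers and
    # commits a single Append transaction against the version it read.
    operation = lance.LanceOperation.Append([fragment])
    dataset = lance.LanceDataset._commit(uri, operation, read_version=1)
    assert dataset.to_table().num_rows == 200

Because Append depends on an existing version, read_version is required here; an Overwrite or Restore operation could omit it, as described in the Rust doc comment above.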