Skip to content

Commit

Permalink
feat(rust): execute_uncommitted for merge_insert
Browse files Browse the repository at this point in the history
expose in python

refactor: make transaction marshalling easier

cleanup

fix tests

fix path backward compatibility

fix repr

get changes back
  • Loading branch information
wjones127 committed Dec 20, 2024
1 parent 10e6454 commit 1173939
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 79 deletions.
93 changes: 80 additions & 13 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Optional,
Sequence,
Set,
Tuple,
TypedDict,
Union,
)
Expand Down Expand Up @@ -102,6 +103,30 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):

return super(MergeInsertBuilder, self).execute(reader)

def execute_uncommitted(
self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None
) -> Tuple[Transaction, Dict[str, Any]]:
"""Executes the merge insert operation without committing
This function updates the original dataset and returns a dictionary with
information about merge statistics - i.e. the number of inserted, updated,
and deleted rows.
Parameters
----------
data_obj: ReaderLike
The new data to use as the source table for the operation. This parameter
can be any source of data (e.g. table / dataset) that
:func:`~lance.write_dataset` accepts.
schema: Optional[pa.Schema]
The schema of the data. This only needs to be supplied whenever the data
source is some kind of generator.
"""
reader = _coerce_reader(data_obj, schema)

return super(MergeInsertBuilder, self).execute_uncommitted(reader)

# These next three overrides exist only to document the methods

def when_matched_update_all(
Expand Down Expand Up @@ -2200,7 +2225,7 @@ def _commit(
@staticmethod
def commit(
base_uri: Union[str, Path, LanceDataset],
operation: LanceOperation.BaseOperation,
operation: Union[LanceOperation.BaseOperation, Transaction],
read_version: Optional[int] = None,
commit_lock: Optional[CommitLock] = None,
storage_options: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -2305,24 +2330,44 @@ def commit(
f"commit_lock must be a function, got {type(commit_lock)}"
)

if read_version is None and not isinstance(
operation, (LanceOperation.Overwrite, LanceOperation.Restore)
if (
isinstance(operation, LanceOperation.BaseOperation)
and read_version is None
and not isinstance(
operation, (LanceOperation.Overwrite, LanceOperation.Restore)
)
):
raise ValueError(
"read_version is required for all operations except "
"Overwrite and Restore"
)

new_ds = _Dataset.commit(
base_uri,
operation,
read_version,
commit_lock,
storage_options=storage_options,
enable_v2_manifest_paths=enable_v2_manifest_paths,
detached=detached,
max_retries=max_retries,
)
if isinstance(operation, Transaction):
new_ds = _Dataset.commit_transaction(
base_uri,
operation,
commit_lock,
storage_options=storage_options,
enable_v2_manifest_paths=enable_v2_manifest_paths,
detached=detached,
max_retries=max_retries,
)
elif isinstance(operation, LanceOperation.BaseOperation):
new_ds = _Dataset.commit(
base_uri,
operation,
read_version,
commit_lock,
storage_options=storage_options,
enable_v2_manifest_paths=enable_v2_manifest_paths,
detached=detached,
max_retries=max_retries,
)
else:
raise TypeError(
"operation must be a LanceOperation.BaseOperation or Transaction, "
f"got {type(operation)}"
)
ds = LanceDataset.__new__(LanceDataset)
ds._storage_options = storage_options
ds._ds = new_ds
Expand Down Expand Up @@ -2666,6 +2711,28 @@ class Delete(BaseOperation):
def __post_init__(self):
LanceOperation._validate_fragments(self.updated_fragments)

@dataclass
class Update(BaseOperation):
"""
Operation that updates rows in the dataset.
Attributes
----------
removed_fragment_ids: list[int]
The ids of the fragments that have been removed entirely.
updated_fragments: list[FragmentMetadata]
The fragments that have been updated with new deletion vectors.
new_fragments: list[FragmentMetadata]
The fragments that contain the new rows.
"""

removed_fragment_ids: List[int]
updated_fragments: List[FragmentMetadata]
new_fragments: List[FragmentMetadata]

def __post_init__(self):
LanceOperation._validate_fragments(self.updated_fragments)
LanceOperation._validate_fragments(self.new_fragments)

@dataclass
class Merge(BaseOperation):
"""
Expand Down
25 changes: 25 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,31 @@ def test_restore_with_commit(tmp_path: Path):
assert tbl == table


def test_merge_insert_with_commit():
table = pa.table({"id": range(10), "updated": [False] * 10})
dataset = lance.write_dataset(table, "memory://test")

updates = pa.Table.from_pylist([{"id": 1, "updated": True}])
transaction, stats = (
dataset.merge_insert(on="id")
.when_matched_update_all()
.execute_uncommitted(updates)
)

assert isinstance(stats, dict)
assert stats["num_updated_rows"] == 1
assert stats["num_inserted_rows"] == 0
assert stats["num_deleted_rows"] == 0

assert isinstance(transaction, lance.Transaction)
assert isinstance(transaction.operation, lance.LanceOperation.Update)

dataset = lance.LanceDataset.commit(dataset, transaction)
assert dataset.to_table().sort_by("id") == pa.table(
{"id": range(10), "updated": [False] + [True] + [False] * 8}
)


def test_merge_with_commit(tmp_path: Path):
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
base_dir = tmp_path / "test"
Expand Down
79 changes: 66 additions & 13 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ use lance::dataset::{
WriteParams,
};
use lance::dataset::{
BatchInfo, BatchUDF, CommitBuilder, NewColumnTransform, UDFCheckpointStore, WriteDestination,
BatchInfo, BatchUDF, CommitBuilder, MergeStats, NewColumnTransform, UDFCheckpointStore,
WriteDestination,
};
use lance::dataset::{ColumnAlteration, ProjectionRequest};
use lance::index::{vector::VectorIndexParams, DatasetIndexInternalExt};
Expand Down Expand Up @@ -194,20 +195,46 @@ impl MergeInsertBuilder {
.try_build()
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let new_self = RT
let (new_dataset, stats) = RT
.spawn(Some(py), job.execute_reader(new_data))?
.map_err(|err| PyIOError::new_err(err.to_string()))?;

let dataset = self.dataset.bind(py);

dataset.borrow_mut().ds = new_self.0;
let merge_stats = new_self.1;
let merge_dict = PyDict::new_bound(py);
merge_dict.set_item("num_inserted_rows", merge_stats.num_inserted_rows)?;
merge_dict.set_item("num_updated_rows", merge_stats.num_updated_rows)?;
merge_dict.set_item("num_deleted_rows", merge_stats.num_deleted_rows)?;
dataset.borrow_mut().ds = new_dataset;

Ok(Self::build_stats(&stats, py)?.into())
}

pub fn execute_uncommitted<'a>(
&mut self,
new_data: &Bound<'a, PyAny>,
) -> PyResult<(PyLance<Transaction>, Bound<'a, PyDict>)> {
let py = new_data.py();
let new_data = convert_reader(new_data)?;

Ok(merge_dict.into())
let job = self
.builder
.try_build()
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let (transaction, stats) = RT
.spawn(Some(py), job.execute_uncommitted(new_data))?
.map_err(|err| PyIOError::new_err(err.to_string()))?;

let stats = Self::build_stats(&stats, py)?;

Ok((PyLance(transaction), stats))
}
}

impl MergeInsertBuilder {
fn build_stats<'a>(stats: &MergeStats, py: Python<'a>) -> PyResult<Bound<'a, PyDict>> {
let dict = PyDict::new_bound(py);
dict.set_item("num_inserted_rows", stats.num_inserted_rows)?;
dict.set_item("num_updated_rows", stats.num_updated_rows)?;
dict.set_item("num_deleted_rows", stats.num_deleted_rows)?;
Ok(dict)
}
}

Expand Down Expand Up @@ -1284,6 +1311,32 @@ impl Dataset {
enable_v2_manifest_paths: Option<bool>,
detached: Option<bool>,
max_retries: Option<u32>,
) -> PyResult<Self> {
let transaction =
Transaction::new(read_version.unwrap_or_default(), operation.0, None, None);

Self::commit_transaction(
dest,
PyLance(transaction),
commit_lock,
storage_options,
enable_v2_manifest_paths,
detached,
max_retries,
)
}

#[allow(clippy::too_many_arguments)]
#[staticmethod]
#[pyo3(signature = (dest, transaction, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))]
fn commit_transaction(
dest: &Bound<PyAny>,
transaction: PyLance<Transaction>,
commit_lock: Option<&Bound<'_, PyAny>>,
storage_options: Option<HashMap<String, String>>,
enable_v2_manifest_paths: Option<bool>,
detached: Option<bool>,
max_retries: Option<u32>,
) -> PyResult<Self> {
let object_store_params =
storage_options
Expand All @@ -1305,9 +1358,6 @@ impl Dataset {
WriteDestination::Uri(dest.extract()?)
};

let transaction =
Transaction::new(read_version.unwrap_or_default(), operation.0, None, None);

let mut builder = CommitBuilder::new(dest)
.enable_v2_manifest_paths(enable_v2_manifest_paths.unwrap_or(false))
.with_detached(detached.unwrap_or(false))
Expand All @@ -1322,7 +1372,10 @@ impl Dataset {
}

let ds = RT
.block_on(commit_lock.map(|cl| cl.py()), builder.execute(transaction))?
.block_on(
commit_lock.map(|cl| cl.py()),
builder.execute(transaction.0),
)?
.map_err(|err| PyIOError::new_err(err.to_string()))?;

let uri = ds.uri().to_string();
Expand Down
29 changes: 29 additions & 0 deletions python/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ impl FromPyObject<'_> for PyLance<Operation> {
};
Ok(Self(op))
}
"Update" => {
let removed_fragment_ids = ob.getattr("removed_fragment_ids")?.extract()?;

let updated_fragments = extract_vec(&ob.getattr("updated_fragments")?)?;

let new_fragments = extract_vec(&ob.getattr("new_fragments")?)?;

let op = Operation::Update {
removed_fragment_ids,
updated_fragments,
new_fragments,
};
Ok(Self(op))
}
"Merge" => {
let schema = extract_schema(&ob.getattr("schema")?)?;

Expand Down Expand Up @@ -126,6 +140,21 @@ impl ToPyObject for PyLance<&Operation> {
.expect("Failed to get Append class");
cls.call1((fragments,)).unwrap().to_object(py)
}
Operation::Update {
removed_fragment_ids,
updated_fragments,
new_fragments,
} => {
let removed_fragment_ids = removed_fragment_ids.to_object(py);
let updated_fragments = export_vec(py, updated_fragments.as_slice());
let new_fragments = export_vec(py, new_fragments.as_slice());
let cls = namespace
.getattr("Update")
.expect("Failed to get Update class");
cls.call1((removed_fragment_ids, updated_fragments, new_fragments))
.unwrap()
.to_object(py)
}
_ => todo!(),
}
}
Expand Down
3 changes: 2 additions & 1 deletion rust/lance/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ pub use schema_evolution::{
};
pub use take::TakeBuilder;
pub use write::merge_insert::{
MergeInsertBuilder, MergeInsertJob, WhenMatched, WhenNotMatched, WhenNotMatchedBySource,
MergeInsertBuilder, MergeInsertJob, MergeStats, WhenMatched, WhenNotMatched,
WhenNotMatchedBySource,
};
pub use write::update::{UpdateBuilder, UpdateJob};
#[allow(deprecated)]
Expand Down
Loading

0 comments on commit 1173939

Please sign in to comment.