Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: execute_uncommitted for merge insert #3233

Merged
merged 2 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 83 additions & 14 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Optional,
Sequence,
Set,
Tuple,
TypedDict,
Union,
)
Expand Down Expand Up @@ -102,6 +103,30 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):

return super(MergeInsertBuilder, self).execute(reader)

def execute_uncommitted(
    self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None
) -> Tuple[Transaction, Dict[str, Any]]:
    """Executes the merge insert operation without committing

    Unlike :meth:`execute`, this does NOT modify the original dataset.
    It runs the merge and returns the resulting :class:`Transaction`
    together with a dictionary of merge statistics - i.e. the number of
    inserted, updated, and deleted rows. The transaction can later be
    applied via :meth:`LanceDataset.commit`.

    Parameters
    ----------

    data_obj: ReaderLike
        The new data to use as the source table for the operation. This parameter
        can be any source of data (e.g. table / dataset) that
        :func:`~lance.write_dataset` accepts.
    schema: Optional[pa.Schema]
        The schema of the data. This only needs to be supplied whenever the data
        source is some kind of generator.

    Returns
    -------
    Tuple[Transaction, Dict[str, Any]]
        The uncommitted transaction and the merge statistics dictionary
        (keys: ``num_inserted_rows``, ``num_updated_rows``,
        ``num_deleted_rows``).
    """
    reader = _coerce_reader(data_obj, schema)

    return super(MergeInsertBuilder, self).execute_uncommitted(reader)

# These next three overrides exist only to document the methods

def when_matched_update_all(
Expand Down Expand Up @@ -2220,7 +2245,7 @@ def _commit(
@staticmethod
def commit(
base_uri: Union[str, Path, LanceDataset],
operation: LanceOperation.BaseOperation,
operation: Union[LanceOperation.BaseOperation, Transaction],
blobs_op: Optional[LanceOperation.BaseOperation] = None,
read_version: Optional[int] = None,
commit_lock: Optional[CommitLock] = None,
Expand Down Expand Up @@ -2326,24 +2351,45 @@ def commit(
f"commit_lock must be a function, got {type(commit_lock)}"
)

if read_version is None and not isinstance(
operation, (LanceOperation.Overwrite, LanceOperation.Restore)
if (
isinstance(operation, LanceOperation.BaseOperation)
and read_version is None
and not isinstance(
operation, (LanceOperation.Overwrite, LanceOperation.Restore)
)
):
raise ValueError(
"read_version is required for all operations except "
"Overwrite and Restore"
)
new_ds = _Dataset.commit(
base_uri,
operation,
blobs_op,
read_version,
commit_lock,
storage_options=storage_options,
enable_v2_manifest_paths=enable_v2_manifest_paths,
detached=detached,
max_retries=max_retries,
)
if isinstance(operation, Transaction):
new_ds = _Dataset.commit_transaction(
base_uri,
operation,
commit_lock,
storage_options=storage_options,
enable_v2_manifest_paths=enable_v2_manifest_paths,
detached=detached,
max_retries=max_retries,
)
elif isinstance(operation, LanceOperation.BaseOperation):
new_ds = _Dataset.commit(
base_uri,
operation,
blobs_op,
read_version,
commit_lock,
storage_options=storage_options,
enable_v2_manifest_paths=enable_v2_manifest_paths,
detached=detached,
max_retries=max_retries,
)
else:
raise TypeError(
"operation must be a LanceOperation.BaseOperation or Transaction, "
f"got {type(operation)}"
)

ds = LanceDataset.__new__(LanceDataset)
ds._storage_options = storage_options
ds._ds = new_ds
Expand Down Expand Up @@ -2722,6 +2768,29 @@ class Delete(BaseOperation):
def __post_init__(self):
LanceOperation._validate_fragments(self.updated_fragments)

@dataclass
class Update(BaseOperation):
    """
    Operation that updates rows in the dataset.

    NOTE(review): per the PR discussion, this operation is also produced by
    merge-insert (upsert), so it is allowed to add new rows (``new_fragments``)
    in addition to rewriting existing ones.

    Attributes
    ----------
    removed_fragment_ids: list[int]
        The ids of the fragments that have been removed entirely.
    updated_fragments: list[FragmentMetadata]
        The fragments that have been updated with new deletion vectors.
    new_fragments: list[FragmentMetadata]
        The fragments that contain the new rows.
    """

    removed_fragment_ids: List[int]
    updated_fragments: List[FragmentMetadata]
    new_fragments: List[FragmentMetadata]

    def __post_init__(self):
        # Both fragment lists must pass the same structural validation used
        # by the other LanceOperation variants.
        LanceOperation._validate_fragments(self.updated_fragments)
        LanceOperation._validate_fragments(self.new_fragments)

@dataclass
class Merge(BaseOperation):
"""
Expand Down
25 changes: 25 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,31 @@ def test_restore_with_commit(tmp_path: Path):
assert tbl == table


def test_merge_insert_with_commit():
    # Seed an in-memory dataset of ten rows, none marked as updated yet.
    source = pa.table({"id": range(10), "updated": [False] * 10})
    dataset = lance.write_dataset(source, "memory://test")

    # Stage an upsert of row id=1 without committing it.
    new_rows = pa.Table.from_pylist([{"id": 1, "updated": True}])
    builder = dataset.merge_insert(on="id").when_matched_update_all()
    transaction, stats = builder.execute_uncommitted(new_rows)

    # Stats should report exactly one matched-and-updated row.
    assert isinstance(stats, dict)
    assert stats["num_updated_rows"] == 1
    assert stats["num_inserted_rows"] == 0
    assert stats["num_deleted_rows"] == 0

    # The staged result is a Transaction wrapping an Update operation.
    assert isinstance(transaction, lance.Transaction)
    assert isinstance(transaction.operation, lance.LanceOperation.Update)

    # Committing the transaction makes the update visible in the dataset.
    dataset = lance.LanceDataset.commit(dataset, transaction)
    expected = pa.table(
        {"id": range(10), "updated": [False] + [True] + [False] * 8}
    )
    assert dataset.to_table().sort_by("id") == expected


def test_merge_with_commit(tmp_path: Path):
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
base_dir = tmp_path / "test"
Expand Down
87 changes: 70 additions & 17 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ use lance::dataset::{
WriteParams,
};
use lance::dataset::{
BatchInfo, BatchUDF, CommitBuilder, NewColumnTransform, UDFCheckpointStore, WriteDestination,
BatchInfo, BatchUDF, CommitBuilder, MergeStats, NewColumnTransform, UDFCheckpointStore,
WriteDestination,
};
use lance::dataset::{ColumnAlteration, ProjectionRequest};
use lance::index::vector::utils::get_vector_type;
Expand Down Expand Up @@ -199,20 +200,46 @@ impl MergeInsertBuilder {
.try_build()
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let new_self = RT
let (new_dataset, stats) = RT
.spawn(Some(py), job.execute_reader(new_data))?
.map_err(|err| PyIOError::new_err(err.to_string()))?;

let dataset = self.dataset.bind(py);

dataset.borrow_mut().ds = new_self.0;
let merge_stats = new_self.1;
let merge_dict = PyDict::new_bound(py);
merge_dict.set_item("num_inserted_rows", merge_stats.num_inserted_rows)?;
merge_dict.set_item("num_updated_rows", merge_stats.num_updated_rows)?;
merge_dict.set_item("num_deleted_rows", merge_stats.num_deleted_rows)?;
dataset.borrow_mut().ds = new_dataset;

Ok(merge_dict.into())
Ok(Self::build_stats(&stats, py)?.into())
}

pub fn execute_uncommitted<'a>(
&mut self,
new_data: &Bound<'a, PyAny>,
) -> PyResult<(PyLance<Transaction>, Bound<'a, PyDict>)> {
let py = new_data.py();
let new_data = convert_reader(new_data)?;

let job = self
.builder
.try_build()
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let (transaction, stats) = RT
.spawn(Some(py), job.execute_uncommitted(new_data))?
.map_err(|err| PyIOError::new_err(err.to_string()))?;

let stats = Self::build_stats(&stats, py)?;
Comment on lines +218 to +230
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nit: this could be encapsulated in a helper method to cut down on repetition as it is shared by the committed variant. This might also help avoid issues in the future where we change one but not the other and don't notice.


Ok((PyLance(transaction), stats))
}
}

impl MergeInsertBuilder {
    /// Convert `MergeStats` into the Python dict of row counts shared by
    /// `execute` and `execute_uncommitted` (keeps the two in sync).
    fn build_stats<'a>(stats: &MergeStats, py: Python<'a>) -> PyResult<Bound<'a, PyDict>> {
        let merge_dict = PyDict::new_bound(py);
        for (key, value) in [
            ("num_inserted_rows", stats.num_inserted_rows),
            ("num_updated_rows", stats.num_updated_rows),
            ("num_deleted_rows", stats.num_deleted_rows),
        ] {
            merge_dict.set_item(key, value)?;
        }
        Ok(merge_dict)
    }
}

Expand Down Expand Up @@ -1312,6 +1339,36 @@ impl Dataset {
enable_v2_manifest_paths: Option<bool>,
detached: Option<bool>,
max_retries: Option<u32>,
) -> PyResult<Self> {
let transaction = Transaction::new(
read_version.unwrap_or_default(),
operation.0,
blobs_op.map(|op| op.0),
None,
);

Self::commit_transaction(
dest,
PyLance(transaction),
commit_lock,
storage_options,
enable_v2_manifest_paths,
detached,
max_retries,
)
}

#[allow(clippy::too_many_arguments)]
#[staticmethod]
#[pyo3(signature = (dest, transaction, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))]
fn commit_transaction(
dest: &Bound<PyAny>,
transaction: PyLance<Transaction>,
commit_lock: Option<&Bound<'_, PyAny>>,
storage_options: Option<HashMap<String, String>>,
enable_v2_manifest_paths: Option<bool>,
detached: Option<bool>,
max_retries: Option<u32>,
) -> PyResult<Self> {
let object_store_params =
storage_options
Expand All @@ -1333,13 +1390,6 @@ impl Dataset {
WriteDestination::Uri(dest.extract()?)
};

let transaction = Transaction::new(
read_version.unwrap_or_default(),
operation.0,
blobs_op.map(|op| op.0),
None,
);

let mut builder = CommitBuilder::new(dest)
.enable_v2_manifest_paths(enable_v2_manifest_paths.unwrap_or(false))
.with_detached(detached.unwrap_or(false))
Expand All @@ -1354,7 +1404,10 @@ impl Dataset {
}

let ds = RT
.block_on(commit_lock.map(|cl| cl.py()), builder.execute(transaction))?
.block_on(
commit_lock.map(|cl| cl.py()),
builder.execute(transaction.0),
)?
.map_err(|err| PyIOError::new_err(err.to_string()))?;

let uri = ds.uri().to_string();
Expand Down
29 changes: 29 additions & 0 deletions python/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ impl FromPyObject<'_> for PyLance<Operation> {
};
Ok(Self(op))
}
"Update" => {
let removed_fragment_ids = ob.getattr("removed_fragment_ids")?.extract()?;

let updated_fragments = extract_vec(&ob.getattr("updated_fragments")?)?;

let new_fragments = extract_vec(&ob.getattr("new_fragments")?)?;

let op = Operation::Update {
removed_fragment_ids,
updated_fragments,
new_fragments,
};
Ok(Self(op))
}
"Merge" => {
let schema = extract_schema(&ob.getattr("schema")?)?;

Expand Down Expand Up @@ -143,6 +157,21 @@ impl ToPyObject for PyLance<&Operation> {
.expect("Failed to create Overwrite instance")
.to_object(py)
}
Operation::Update {
removed_fragment_ids,
updated_fragments,
new_fragments,
} => {
let removed_fragment_ids = removed_fragment_ids.to_object(py);
let updated_fragments = export_vec(py, updated_fragments.as_slice());
let new_fragments = export_vec(py, new_fragments.as_slice());
let cls = namespace
.getattr("Update")
.expect("Failed to get Update class");
cls.call1((removed_fragment_ids, updated_fragments, new_fragments))
.unwrap()
.to_object(py)
}
_ => todo!(),
}
}
Expand Down
3 changes: 2 additions & 1 deletion rust/lance/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ pub use schema_evolution::{
};
pub use take::TakeBuilder;
pub use write::merge_insert::{
MergeInsertBuilder, MergeInsertJob, WhenMatched, WhenNotMatched, WhenNotMatchedBySource,
MergeInsertBuilder, MergeInsertJob, MergeStats, WhenMatched, WhenNotMatched,
WhenNotMatchedBySource,
};
pub use write::update::{UpdateBuilder, UpdateJob};
#[allow(deprecated)]
Expand Down
Loading
Loading