diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index d2d03bd76a..988b8e2322 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -751,6 +751,10 @@ def _commit( _Dataset.commit(base_uri, operation._to_inner(), read_version, commit_lock) return LanceDataset(base_uri) + @property + def optimize(self) -> "DatasetOptimizer": + return DatasetOptimizer(self) + # LanceOperation is a namespace for operations that can be applied to a dataset. class LanceOperation: @@ -816,17 +820,35 @@ def _to_inner(self): @dataclass class Rewrite(BaseOperation): - old_fragments: Iterable[FragmentMetadata] - new_fragments: Iterable[FragmentMetadata] + """ + Operation that rewrites fragments but does not change the data within them. - def __post_init__(self): - LanceOperation._validate_fragments(self.old_fragments) - LanceOperation._validate_fragments(self.new_fragments) + This is for rearranging the data. + + The data are grouped, such that each group contains the old fragments + and the new fragments they are rewritten into. 
+ """ + + groups: Iterable[RewriteGroup] + + @dataclass + class RewriteGroup: + old_fragments: Iterable[FragmentMetadata] + new_fragments: Iterable[FragmentMetadata] + + def __post_init__(self): + LanceOperation._validate_fragments(self.old_fragments) + LanceOperation._validate_fragments(self.new_fragments) def _to_inner(self): - raw_old_fragments = [f._metadata for f in self.old_fragments] - raw_new_fragments = [f._metadata for f in self.new_fragments] - return _Operation.rewrite(raw_old_fragments, raw_new_fragments) + groups = [ + ( + [f._metadata for f in g.old_fragments], + [f._metadata for f in g.new_fragments], + ) + for g in self.groups + ] + return _Operation.rewrite(groups) @dataclass class Merge(BaseOperation): @@ -847,10 +869,6 @@ class Restore(BaseOperation): def _to_inner(self): return _Operation.restore(self.version) - @property - def optimize(self) -> "DatasetOptimizer": - return DatasetOptimizer(self) - class ScannerBuilder: def __init__(self, ds: LanceDataset): diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 50f7b8a3b9..d97c903021 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -460,7 +460,8 @@ def test_rewrite_with_commit(tmp_path: Path): to_be_rewrote = [lf.metadata for lf in lance.dataset(base_dir).get_fragments()] fragment = lance.fragment.LanceFragment.create(base_dir, combined) - rewrite = lance.LanceOperation.Rewrite(to_be_rewrote, [fragment]) + group = lance.LanceOperation.Rewrite.RewriteGroup(to_be_rewrote, [fragment]) + rewrite = lance.LanceOperation.Rewrite([group]) dataset = lance.LanceDataset._commit(base_dir, rewrite, read_version=1) diff --git a/python/src/dataset.rs b/python/src/dataset.rs index decb3985d2..6ccf865adb 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -21,7 +21,7 @@ use arrow_array::{Float32Array, RecordBatch}; use arrow_data::ArrayData; use arrow_schema::Schema as ArrowSchema; use 
lance::arrow::as_fixed_size_list_array; - +use lance::dataset::transaction::RewriteGroup; use lance::dataset::{ fragment::FileFragment as LanceFileFragment, scanner::Scanner as LanceScanner, transaction::Operation as LanceOperation, Dataset as LanceDataset, ReadParams, Version, @@ -116,16 +116,15 @@ impl Operation { } #[staticmethod] - fn rewrite( - old_fragments: Vec, - new_fragments: Vec, - ) -> PyResult { - let old_fragments = into_fragments(old_fragments); - let new_fragments = into_fragments(new_fragments); - let op = LanceOperation::Rewrite { - old_fragments, - new_fragments, - }; + fn rewrite(groups: Vec<(Vec, Vec)>) -> PyResult { + let groups = groups + .into_iter() + .map(|(old_fragments, new_fragments)| RewriteGroup { + old_fragments: into_fragments(old_fragments), + new_fragments: into_fragments(new_fragments), + }) + .collect::>(); + let op = LanceOperation::Rewrite { groups }; Ok(Self(op)) }