feat: update the commit function to work with transactions #1193

Merged
9 changes: 8 additions & 1 deletion python/python/lance/__init__.py
@@ -26,13 +26,20 @@
 pd = None
 ts_types = Union[datetime, str]
 
-from .dataset import LanceDataset, LanceScanner, __version__, write_dataset
+from .dataset import (
+    LanceDataset,
+    LanceOperation,
+    LanceScanner,
+    __version__,
+    write_dataset,
+)
 from .fragment import FragmentMetadata, LanceFragment
 from .schema import json_to_schema, schema_to_json
 from .util import sanitize_ts
 
 __all__ = [
     "LanceDataset",
+    "LanceOperation",
     "LanceScanner",
     "__version__",
     "write_dataset",
144 changes: 127 additions & 17 deletions python/python/lance/dataset.py
@@ -17,6 +17,8 @@
 
 import json
 import os
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from datetime import datetime, timedelta
 from functools import lru_cache
 from pathlib import Path
@@ -31,7 +33,7 @@
 from .commit import CommitLock
 from .fragment import FragmentMetadata, LanceFragment
 from .lance import __version__ as __version__
-from .lance import _Dataset, _Scanner, _write_dataset
+from .lance import _Dataset, _Operation, _Scanner, _write_dataset
 
 try:
     import pandas as pd
@@ -680,21 +682,37 @@ def create_index(
     @staticmethod
     def _commit(
         base_uri: Union[str, Path],
-        new_schema: pa.Schema,
-        fragments: Iterable[FragmentMetadata],
-        mode: str = "append",
+        operation: LanceOperation.BaseOperation,
+        read_version: Optional[int] = None,
+        commit_lock: Optional[CommitLock] = None,
     ) -> LanceDataset:
-        """Create a new version of dataset with collected fragments.
+        """Create a new version of the dataset
+
+        This is an advanced method that allows users to describe a change that
+        has been made to the data files. It is not needed when using Lance to
+        apply changes (e.g. when using :py:class:`LanceDataset` or
+        :py:func:`write_dataset`).
+
+        Its current purpose is to allow changes to be made in a distributed
+        environment where no single process is doing all of the work. For
+        example, a distributed bulk update or a distributed bulk modify
+        operation.
 
-        This method allows users to commit a version of dataset in a distributed
-        environment.
+        Once all of the changes have been made, this method can be called to make
+        the changes visible by updating the dataset manifest.
 
         Parameters
         ----------
-        new_schema : pa.Schema
-            The schema for the new version of dataset.
-        fragments : list[FragmentMetadata]
-            The fragments to create new version of dataset.
+        base_uri : str or Path
+            The base uri of the dataset.
+        operation : BaseOperation
+            The operation to apply to the dataset. This describes what changes
+            have been made.
+        read_version : int, optional
+            The version of the dataset that was used as the base for the changes.
+            This is not needed for overwrite or restore operations.
+        commit_lock : CommitLock, optional
+            A custom commit lock. Only needed if your object store does not support
+            atomic commits. See the user guide for more details.
 
         Returns
         -------
@@ -708,8 +726,21 @@ def _commit(
"""
if isinstance(base_uri, Path):
base_uri = str(base_uri)
if not isinstance(new_schema, pa.Schema):
raise TypeError(f"schema must be pyarrow.Schema, got {type(new_schema)}")

if commit_lock:
if not callable(commit_lock):
raise TypeError(
f"commit_lock must be a function, got {type(commit_lock)}"
)

_Dataset.commit(base_uri, operation._to_inner(), read_version, commit_lock)
return LanceDataset(base_uri)


# LanceOperation is a namespace for operations that can be applied to a dataset.
class LanceOperation:
wjones127 marked this conversation as resolved.
Show resolved Hide resolved
@staticmethod
def _validate_fragments(fragments):
if not isinstance(fragments, list):
raise TypeError(
f"fragments must be list[FragmentMetadata], got {type(fragments)}"
@@ -720,10 +751,86 @@ def _commit(
                 raise TypeError(
                     f"fragments must be list[FragmentMetadata], got {type(fragments[0])}"
                 )
-        raw_fragments = [f._metadata for f in fragments]
-        # TODO: make fragments as a generator
-        _Dataset.commit(base_uri, new_schema, raw_fragments)
-        return LanceDataset(base_uri)
+
+    class BaseOperation(ABC):
+        @abstractmethod
+        def _to_inner(self):
+            raise NotImplementedError()
+
+    @dataclass
+    class Overwrite(BaseOperation):
+        new_schema: pa.Schema
+        fragments: Iterable[FragmentMetadata]
+
+        def __post_init__(self):
+            if not isinstance(self.new_schema, pa.Schema):
+                raise TypeError(
+                    f"schema must be pyarrow.Schema, got {type(self.new_schema)}"
+                )
+            LanceOperation._validate_fragments(self.fragments)
+
+        def _to_inner(self):
+            raw_fragments = [f._metadata for f in self.fragments]
+            return _Operation.overwrite(self.new_schema, raw_fragments)
+
+    @dataclass
+    class Append(BaseOperation):
+        fragments: Iterable[FragmentMetadata]
+
+        def __post_init__(self):
+            LanceOperation._validate_fragments(self.fragments)
+
+        def _to_inner(self):
+            raw_fragments = [f._metadata for f in self.fragments]
+            return _Operation.append(raw_fragments)
+
+    @dataclass
+    class Delete(BaseOperation):
+        updated_fragments: Iterable[FragmentMetadata]
+        deleted_fragment_ids: Iterable[int]
+        predicate: str
+
+        def __post_init__(self):
+            LanceOperation._validate_fragments(self.updated_fragments)
+
+        def _to_inner(self):
+            raw_updated_fragments = [f._metadata for f in self.updated_fragments]
+            return _Operation.delete(
+                raw_updated_fragments, self.deleted_fragment_ids, self.predicate
+            )
+
+    @dataclass
+    class Rewrite(BaseOperation):
+        old_fragments: Iterable[FragmentMetadata]
+        new_fragments: Iterable[FragmentMetadata]
[Review comment from a Contributor on lines +804 to +805: FYI this is changing in #1095. If we merge this first, let's just remove this operation from public API for now.]
+
+        def __post_init__(self):
+            LanceOperation._validate_fragments(self.old_fragments)
+            LanceOperation._validate_fragments(self.new_fragments)
+
+        def _to_inner(self):
+            raw_old_fragments = [f._metadata for f in self.old_fragments]
+            raw_new_fragments = [f._metadata for f in self.new_fragments]
+            return _Operation.rewrite(raw_old_fragments, raw_new_fragments)
+
+    @dataclass
+    class Merge(BaseOperation):
+        fragments: Iterable[FragmentMetadata]
+        schema: pa.Schema
+
+        def __post_init__(self):
+            LanceOperation._validate_fragments(self.fragments)
+
+        def _to_inner(self):
+            raw_fragments = [f._metadata for f in self.fragments]
+            return _Operation.merge(raw_fragments, self.schema)
+
+    @dataclass
+    class Restore(BaseOperation):
+        version: int
+
+        def _to_inner(self):
+            return _Operation.restore(self.version)
 
 
 class ScannerBuilder:
@@ -991,6 +1098,9 @@ def write_dataset(
         The max number of rows to write before starting a new file
     max_rows_per_group: int, default 1024
         The max number of rows before starting a new group (in the same file)
+    commit_lock : CommitLock, optional
+        A custom commit lock. Only needed if your object store does not support
+        atomic commits. See the user guide for more details.
 
     """
     reader = _coerce_reader(data_obj, schema)
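Taken together, the new _commit signature and the LanceOperation classes enable a two-phase distributed write: each worker stages data files with LanceFragment.create, which writes fragments without making them visible, and a single coordinator then publishes everything in one transactional commit. Below is a minimal single-machine sketch of that pattern, modeled on the tests in this PR; the multiprocessing setup, pool size, and paths are illustrative and not part of the Lance API.

from multiprocessing import Pool

import lance
import pyarrow as pa

BASE_DIR = "/tmp/distributed_commit_demo"


def write_partition(part: int):
    # Phase 1 (runs on each worker): write the data files for one partition.
    # Nothing becomes visible to readers at this point.
    table = pa.table({"a": range(part * 100, (part + 1) * 100)})
    return lance.fragment.LanceFragment.create(BASE_DIR, table)


if __name__ == "__main__":
    # FragmentMetadata pickles cleanly, so workers can simply return it.
    with Pool(4) as pool:
        fragments = pool.map(write_partition, range(4))

    # Phase 2 (runs once, on the coordinator): a single commit makes all
    # four fragments visible as one new dataset version.
    schema = pa.schema([pa.field("a", pa.int64())])
    operation = lance.LanceOperation.Overwrite(schema, fragments)
    dataset = lance.LanceDataset._commit(BASE_DIR, operation)
    assert dataset.count_rows() == 400

An Append against an existing dataset works the same way, except that read_version must be supplied so conflicting concurrent commits can be detected, as test_append_with_commit below demonstrates.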
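The commit_lock parameter documented above, for both _commit and write_dataset, exists for object stores that cannot perform an atomic commit natively. The code in this diff only checks that the lock is callable; the sketch below assumes the contract described in the Lance user guide, namely a callable that takes the version being committed and returns a context manager held while the commit is written, so verify the exact protocol there. The local-file lock is only a stand-in to keep the example self-contained; a real deployment would lock through an external service such as DynamoDB or ZooKeeper.

import contextlib
import os
import time

import lance
import pyarrow as pa


@contextlib.contextmanager
def simple_commit_lock(version: int):
    # Atomically create a lock file; O_EXCL makes os.open fail if it exists.
    lock_path = f"/tmp/demo_commit_{version}.lock"
    while True:
        try:
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL)
            break
        except FileExistsError:
            time.sleep(0.1)  # another writer holds the lock; retry
    try:
        yield  # the commit happens while the lock is held
    finally:
        os.close(fd)
        os.remove(lock_path)


table = pa.table({"a": range(100)})
lance.write_dataset(table, "/tmp/locked_commit_demo", commit_lock=simple_commit_lock)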
2 changes: 1 addition & 1 deletion python/python/lance/fragment.py
@@ -347,4 +347,4 @@ def metadata(self) -> FragmentMetadata:
         FragmentMetadata
         """
 
-        return FragmentMetadata(self._fragment.metadata())
+        return FragmentMetadata(self._fragment.metadata().json())
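This fragment.py change routes FragmentMetadata construction through its JSON representation, which is what lets fragment metadata travel between processes. A small sketch of that round trip via pickle, mirroring test_commit_fragments_via_scanner below; the table contents and path are illustrative.

import pickle

import lance
import pyarrow as pa

# A worker writes a staged fragment and hands back its metadata.
table = pa.table({"a": range(10), "b": range(10)})
fragment_meta = lance.fragment.LanceFragment.create("/tmp/fragment_pickle_demo", table)

# FragmentMetadata survives pickling, so it can be returned from a
# multiprocessing / Ray / Spark task to the process that performs the commit.
restored = pickle.loads(pickle.dumps(fragment_meta))
assert restored == fragment_meta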
123 changes: 118 additions & 5 deletions python/python/tests/test_dataset.py
@@ -362,7 +362,8 @@ def adder(batch: pa.RecordBatch) -> pa.RecordBatch:
     fragment_metadata = fragment.add_columns(adder, columns=["a"])
     schema = dataset.schema.append(pa.field("c", pa.int64()))
 
-    dataset = lance.LanceDataset._commit(base_dir, schema, [fragment_metadata])
+    operation = lance.LanceOperation.Overwrite(schema, [fragment_metadata])
+    dataset = lance.LanceDataset._commit(base_dir, operation)
     assert dataset.schema == schema
 
     tbl = dataset.to_table()
@@ -371,16 +372,126 @@
     )
 
 
-def test_create_from_fragments(tmp_path: Path):
+def test_create_from_commit(tmp_path: Path):
     table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
     base_dir = tmp_path / "test"
     fragment = lance.fragment.LanceFragment.create(base_dir, table)
 
-    dataset = lance.LanceDataset._commit(base_dir, table.schema, [fragment])
+    operation = lance.LanceOperation.Overwrite(table.schema, [fragment])
+    dataset = lance.LanceDataset._commit(base_dir, operation)
     tbl = dataset.to_table()
     assert tbl == table
 
 
+def test_append_with_commit(tmp_path: Path):
+    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
+    base_dir = tmp_path / "test"
+
+    lance.write_dataset(table, base_dir)
+
+    fragment = lance.fragment.LanceFragment.create(base_dir, table)
+    append = lance.LanceOperation.Append([fragment])
+
+    with pytest.raises(OSError):
+        # Must specify read version
+        dataset = lance.LanceDataset._commit(base_dir, append)
+
+    dataset = lance.LanceDataset._commit(base_dir, append, read_version=1)
+
+    tbl = dataset.to_table()
+
+    expected = pa.Table.from_pydict(
+        {
+            "a": list(range(100)) + list(range(100)),
+            "b": list(range(100)) + list(range(100)),
+        }
+    )
+    assert tbl == expected
+
+
+def test_delete_with_commit(tmp_path: Path):
+    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
+    base_dir = tmp_path / "test"
+
+    lance.write_dataset(table, base_dir)
+    lance.write_dataset(table, base_dir, mode="append")
+
+    half_table = pa.Table.from_pydict({"a": range(50), "b": range(50)})
+
+    fragments = lance.dataset(base_dir).get_fragments()
+
+    updated_fragment = fragments[0].delete("a >= 50")
+    delete = lance.LanceOperation.Delete(
+        [updated_fragment], [fragments[1].fragment_id], "hello"
+    )
+
+    dataset = lance.LanceDataset._commit(base_dir, delete, read_version=2)
+
+    tbl = dataset.to_table()
+    assert tbl == half_table
+
+
+def test_rewrite_with_commit(tmp_path: Path):
+    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
+    base_dir = tmp_path / "test"
+
+    lance.write_dataset(table, base_dir)
+    lance.write_dataset(table, base_dir, mode="append")
+
+    combined = pa.Table.from_pydict(
+        {
+            "a": list(range(100)) + list(range(100)),
+            "b": list(range(100)) + list(range(100)),
+        }
+    )
+
+    to_be_rewritten = [lf.metadata for lf in lance.dataset(base_dir).get_fragments()]
+
+    fragment = lance.fragment.LanceFragment.create(base_dir, combined)
+    rewrite = lance.LanceOperation.Rewrite(to_be_rewritten, [fragment])
+
+    dataset = lance.LanceDataset._commit(base_dir, rewrite, read_version=1)
+
+    tbl = dataset.to_table()
+    assert tbl == combined
+
+    assert len(lance.dataset(base_dir).get_fragments()) == 1
+
+
+def test_restore_with_commit(tmp_path: Path):
+    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
+    base_dir = tmp_path / "test"
+
+    lance.write_dataset(table, base_dir)
+    lance.write_dataset(table, base_dir, mode="append")
+
+    restore = lance.LanceOperation.Restore(1)
+    dataset = lance.LanceDataset._commit(base_dir, restore)
+
+    tbl = dataset.to_table()
+    assert tbl == table
+
+
+def test_merge_with_commit(tmp_path: Path):
+    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
+    base_dir = tmp_path / "test"
+
+    lance.write_dataset(table, base_dir)
+
+    fragment = lance.dataset(base_dir).get_fragments()[0]
+    merged = fragment.add_columns(
+        lambda _: pa.RecordBatch.from_pydict({"c": range(100)})
+    )
+
+    expected = pa.Table.from_pydict({"a": range(100), "b": range(100), "c": range(100)})
+
+    merge = lance.LanceOperation.Merge([merged], expected.schema)
+    dataset = lance.LanceDataset._commit(base_dir, merge, read_version=1)
+
+    tbl = dataset.to_table()
+    assert tbl == expected
+
+
 def test_data_files(tmp_path: Path):
     table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
     base_dir = tmp_path / "test"
@@ -411,7 +522,8 @@ def test_deletion_file(tmp_path: Path):
     assert new_fragment.deletion_file() is not None
     assert re.match("_deletions/0-1-[0-9]{1,32}.arrow", new_fragment.deletion_file())
     print(type(new_fragment), new_fragment)
-    dataset = lance.LanceDataset._commit(base_dir, table.schema, [new_fragment])
+    operation = lance.LanceOperation.Overwrite(table.schema, [new_fragment])
+    dataset = lance.LanceDataset._commit(base_dir, operation)
     assert dataset.count_rows() == 90


@@ -429,7 +541,8 @@ def test_commit_fragments_via_scanner(tmp_path: Path):
     unpickled = pickle.loads(pickled)
     assert fragment_metadata == unpickled
 
-    dataset = lance.LanceDataset._commit(base_dir, table.schema, [fragment_metadata])
+    operation = lance.LanceOperation.Overwrite(table.schema, [fragment_metadata])
+    dataset = lance.LanceDataset._commit(base_dir, operation)
     assert dataset.schema == table.schema
 
     tbl = dataset.to_table()
6 changes: 4 additions & 2 deletions python/python/tests/test_fragment.py
@@ -20,7 +20,7 @@
 import pandas as pd
 import pyarrow as pa
 import pytest
-from lance import FragmentMetadata, LanceDataset, LanceFragment
+from lance import FragmentMetadata, LanceDataset, LanceFragment, LanceOperation
 from lance.progress import FileSystemFragmentWriteProgress, FragmentWriteProgress


@@ -50,7 +50,9 @@ def test_write_fragment_two_phases(tmp_path: Path):
     fragments = [FragmentMetadata.from_json(j) for j in json_array]
 
     schema = pa.schema([pa.field("a", pa.int64())])
-    dataset = LanceDataset._commit(tmp_path, schema, fragments)
+
+    operation = LanceOperation.Overwrite(schema, fragments)
+    dataset = LanceDataset._commit(tmp_path, operation)
 
     df = dataset.to_table().to_pandas()
     pd.testing.assert_frame_equal(