feat: update the commit function to work with transactions (#1193)
Previously, the commit function did not use the new transactions feature.
This meant that a commit running concurrently with another operation
could lead to incorrect behavior. The commit operation was also limited
mostly to overwrites.

Now commits can apply any kind of transaction operation.
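
For example, a fragment written out-of-band can now be committed as an
append instead of an overwrite. A minimal sketch based on the tests in
this PR (`base_dir` and `table` are assumed to already exist):

    import lance

    # Write a new fragment; this creates data files but does not update
    # the dataset manifest, so the rows are not yet visible.
    fragment = lance.fragment.LanceFragment.create(base_dir, table)

    # Commit the change as a transaction. Append must name the dataset
    # version the change was based on.
    operation = lance.LanceOperation.Append([fragment])
    dataset = lance.LanceDataset._commit(base_dir, operation, read_version=1)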

---------

Co-authored-by: Will Jones <[email protected]>
westonpace and wjones127 authored Sep 8, 2023
1 parent 5ab23cf commit 8c46445
Showing 9 changed files with 468 additions and 95 deletions.
9 changes: 8 additions & 1 deletion python/python/lance/__init__.py
@@ -26,13 +26,20 @@
pd = None
ts_types = Union[datetime, str]

from .dataset import LanceDataset, LanceScanner, __version__, write_dataset
from .dataset import (
    LanceDataset,
    LanceOperation,
    LanceScanner,
    __version__,
    write_dataset,
)
from .fragment import FragmentMetadata, LanceFragment
from .schema import json_to_schema, schema_to_json
from .util import sanitize_ts

__all__ = [
    "LanceDataset",
    "LanceOperation",
    "LanceScanner",
    "__version__",
    "write_dataset",
144 changes: 127 additions & 17 deletions python/python/lance/dataset.py
@@ -17,6 +17,8 @@

import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import lru_cache
from pathlib import Path
@@ -31,7 +33,7 @@
from .commit import CommitLock
from .fragment import FragmentMetadata, LanceFragment
from .lance import __version__ as __version__
from .lance import _Dataset, _Scanner, _write_dataset
from .lance import _Dataset, _Operation, _Scanner, _write_dataset

try:
    import pandas as pd
@@ -680,21 +682,37 @@ def create_index(
    @staticmethod
    def _commit(
        base_uri: Union[str, Path],
        new_schema: pa.Schema,
        fragments: Iterable[FragmentMetadata],
        mode: str = "append",
        operation: LanceOperation.BaseOperation,
        read_version: Optional[int] = None,
        commit_lock: Optional[CommitLock] = None,
    ) -> LanceDataset:
"""Create a new version of dataset with collected fragments.
"""Create a new version of dataset
This method is an advanced method which allows users to describe a change
that has been made to the data files. This method is not needed when using
Lance to apply changes (e.g. when using :py:class:`LanceDataset` or
:py:func:`write_dataset`.)
It's current purpose is to allow for changes being made in a distributed
environment where no single process is doing all of the work. For example,
a distributed bulk update or a distributed bulk modify operation.
This method allows users to commit a version of dataset in a distributed
environment.
Once all of the changes have been made, this method can be called to make
the changes visible by updating the dataset manifest.
        Parameters
        ----------
        new_schema : pa.Schema
            The schema for the new version of dataset.
        fragments : list[FragmentMetadata]
            The fragments to create new version of dataset.
        base_uri: str or Path
            The base uri of the dataset
        operation: BaseOperation
            The operation to apply to the dataset. This describes what changes
            have been made.
        read_version: int, optional
            The version of the dataset that was used as the base for the changes.
            This is not needed for overwrite or restore operations.
        commit_lock : CommitLock, optional
            A custom commit lock. Only needed if your object store does not support
            atomic commits. See the user guide for more details.
        Returns
        -------
@@ -708,8 +726,21 @@ def _commit(
"""
if isinstance(base_uri, Path):
base_uri = str(base_uri)
if not isinstance(new_schema, pa.Schema):
raise TypeError(f"schema must be pyarrow.Schema, got {type(new_schema)}")

if commit_lock:
if not callable(commit_lock):
raise TypeError(
f"commit_lock must be a function, got {type(commit_lock)}"
)

_Dataset.commit(base_uri, operation._to_inner(), read_version, commit_lock)
return LanceDataset(base_uri)


# LanceOperation is a namespace for operations that can be applied to a dataset.
class LanceOperation:
    @staticmethod
    def _validate_fragments(fragments):
        if not isinstance(fragments, list):
            raise TypeError(
                f"fragments must be list[FragmentMetadata], got {type(fragments)}"
@@ -720,10 +751,86 @@ def _commit(
            raise TypeError(
                f"fragments must be list[FragmentMetadata], got {type(fragments[0])}"
            )
        raw_fragments = [f._metadata for f in fragments]
        # TODO: make fragments as a generator
        _Dataset.commit(base_uri, new_schema, raw_fragments)
        return LanceDataset(base_uri)

    class BaseOperation(ABC):
        @abstractmethod
        def _to_inner(self):
            raise NotImplementedError()

    @dataclass
    class Overwrite(BaseOperation):
        new_schema: pa.Schema
        fragments: Iterable[FragmentMetadata]

        def __post_init__(self):
            if not isinstance(self.new_schema, pa.Schema):
                raise TypeError(
                    f"schema must be pyarrow.Schema, got {type(self.new_schema)}"
                )
            LanceOperation._validate_fragments(self.fragments)

        def _to_inner(self):
            raw_fragments = [f._metadata for f in self.fragments]
            return _Operation.overwrite(self.new_schema, raw_fragments)

    @dataclass
    class Append(BaseOperation):
        fragments: Iterable[FragmentMetadata]

        def __post_init__(self):
            LanceOperation._validate_fragments(self.fragments)

        def _to_inner(self):
            raw_fragments = [f._metadata for f in self.fragments]
            return _Operation.append(raw_fragments)

    @dataclass
    class Delete(BaseOperation):
        updated_fragments: Iterable[FragmentMetadata]
        deleted_fragment_ids: Iterable[int]
        predicate: str

        def __post_init__(self):
            LanceOperation._validate_fragments(self.updated_fragments)

        def _to_inner(self):
            raw_updated_fragments = [f._metadata for f in self.updated_fragments]
            return _Operation.delete(
                raw_updated_fragments, self.deleted_fragment_ids, self.predicate
            )

    @dataclass
    class Rewrite(BaseOperation):
        old_fragments: Iterable[FragmentMetadata]
        new_fragments: Iterable[FragmentMetadata]

        def __post_init__(self):
            LanceOperation._validate_fragments(self.old_fragments)
            LanceOperation._validate_fragments(self.new_fragments)

        def _to_inner(self):
            raw_old_fragments = [f._metadata for f in self.old_fragments]
            raw_new_fragments = [f._metadata for f in self.new_fragments]
            return _Operation.rewrite(raw_old_fragments, raw_new_fragments)

    @dataclass
    class Merge(BaseOperation):
        fragments: Iterable[FragmentMetadata]
        schema: pa.Schema

        def __post_init__(self):
            LanceOperation._validate_fragments(self.fragments)

        def _to_inner(self):
            raw_fragments = [f._metadata for f in self.fragments]
            return _Operation.merge(raw_fragments, self.schema)

    @dataclass
    class Restore(BaseOperation):
        version: int

        def _to_inner(self):
            return _Operation.restore(self.version)


class ScannerBuilder:
@@ -991,6 +1098,9 @@ def write_dataset(
        The max number of rows to write before starting a new file
    max_rows_per_group: int, default 1024
        The max number of rows before starting a new group (in the same file)
    commit_lock : CommitLock, optional
        A custom commit lock. Only needed if your object store does not support
        atomic commits. See the user guide for more details.
    """
    reader = _coerce_reader(data_obj, schema)
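
Taken together, `LanceOperation` and the new `_commit` signature support a
two-phase distributed write: workers create fragments independently, then one
process commits them atomically. A minimal sketch, assuming `base_dir`, a
shared `schema`, and a list of per-worker PyArrow tables called
`worker_tables` (a hypothetical name):

    import lance

    # Phase 1 (each worker): write data files only; nothing is visible yet.
    fragments = [
        lance.fragment.LanceFragment.create(base_dir, tbl) for tbl in worker_tables
    ]

    # Phase 2 (driver): describe the change and commit it in one step.
    operation = lance.LanceOperation.Overwrite(schema, fragments)
    dataset = lance.LanceDataset._commit(base_dir, operation)

Per the docstring above, Overwrite and Restore do not require `read_version`;
the other operations do, which is presumably how conflicting concurrent
commits are detected.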
2 changes: 1 addition & 1 deletion python/python/lance/fragment.py
@@ -347,4 +347,4 @@ def metadata(self) -> FragmentMetadata:
        FragmentMetadata
        """

        return FragmentMetadata(self._fragment.metadata())
        return FragmentMetadata(self._fragment.metadata().json())
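
This change makes `FragmentMetadata` round-trip through its JSON
representation, which is what allows fragment metadata to be pickled and
shipped from a worker to the committing process (see
`test_commit_fragments_via_scanner` below). A minimal sketch, assuming
`base_dir` and `table` exist:

    import pickle

    import lance

    # Worker side: create a fragment and serialize its metadata for transport.
    fragment_metadata = lance.fragment.LanceFragment.create(base_dir, table)
    payload = pickle.dumps(fragment_metadata)

    # Driver side: reconstruct the metadata and commit it as an append.
    operation = lance.LanceOperation.Append([pickle.loads(payload)])
    lance.LanceDataset._commit(base_dir, operation, read_version=1)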
123 changes: 118 additions & 5 deletions python/python/tests/test_dataset.py
@@ -362,7 +362,8 @@ def adder(batch: pa.RecordBatch) -> pa.RecordBatch:
    fragment_metadata = fragment.add_columns(adder, columns=["a"])
    schema = dataset.schema.append(pa.field("c", pa.int64()))

    dataset = lance.LanceDataset._commit(base_dir, schema, [fragment_metadata])
    operation = lance.LanceOperation.Overwrite(schema, [fragment_metadata])
    dataset = lance.LanceDataset._commit(base_dir, operation)
    assert dataset.schema == schema

    tbl = dataset.to_table()
@@ -371,16 +372,126 @@
    )


def test_create_from_fragments(tmp_path: Path):
def test_create_from_commit(tmp_path: Path):
    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
    base_dir = tmp_path / "test"
    fragment = lance.fragment.LanceFragment.create(base_dir, table)

    dataset = lance.LanceDataset._commit(base_dir, table.schema, [fragment])
    operation = lance.LanceOperation.Overwrite(table.schema, [fragment])
    dataset = lance.LanceDataset._commit(base_dir, operation)
    tbl = dataset.to_table()
    assert tbl == table


def test_append_with_commit(tmp_path: Path):
    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
    base_dir = tmp_path / "test"

    lance.write_dataset(table, base_dir)

    fragment = lance.fragment.LanceFragment.create(base_dir, table)
    append = lance.LanceOperation.Append([fragment])

    with pytest.raises(OSError):
        # Must specify read version
        dataset = lance.LanceDataset._commit(base_dir, append)

    dataset = lance.LanceDataset._commit(base_dir, append, read_version=1)

    tbl = dataset.to_table()

    expected = pa.Table.from_pydict(
        {
            "a": list(range(100)) + list(range(100)),
            "b": list(range(100)) + list(range(100)),
        }
    )
    assert tbl == expected


def test_delete_with_commit(tmp_path: Path):
    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
    base_dir = tmp_path / "test"

    lance.write_dataset(table, base_dir)
    lance.write_dataset(table, base_dir, mode="append")

    half_table = pa.Table.from_pydict({"a": range(50), "b": range(50)})

    fragments = lance.dataset(base_dir).get_fragments()

    updated_fragment = fragments[0].delete("a >= 50")
    delete = lance.LanceOperation.Delete(
        [updated_fragment], [fragments[1].fragment_id], "hello"
    )

    dataset = lance.LanceDataset._commit(base_dir, delete, read_version=2)

    tbl = dataset.to_table()
    assert tbl == half_table


def test_rewrite_with_commit(tmp_path: Path):
    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
    base_dir = tmp_path / "test"

    lance.write_dataset(table, base_dir)
    lance.write_dataset(table, base_dir, mode="append")

    combined = pa.Table.from_pydict(
        {
            "a": list(range(100)) + list(range(100)),
            "b": list(range(100)) + list(range(100)),
        }
    )

    to_be_rewrote = [lf.metadata for lf in lance.dataset(base_dir).get_fragments()]

    fragment = lance.fragment.LanceFragment.create(base_dir, combined)
    rewrite = lance.LanceOperation.Rewrite(to_be_rewrote, [fragment])

    dataset = lance.LanceDataset._commit(base_dir, rewrite, read_version=1)

    tbl = dataset.to_table()
    assert tbl == combined

    assert len(lance.dataset(base_dir).get_fragments()) == 1


def test_restore_with_commit(tmp_path: Path):
    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
    base_dir = tmp_path / "test"

    lance.write_dataset(table, base_dir)
    lance.write_dataset(table, base_dir, mode="append")

    restore = lance.LanceOperation.Restore(1)
    dataset = lance.LanceDataset._commit(base_dir, restore)

    tbl = dataset.to_table()
    assert tbl == table


def test_merge_with_commit(tmp_path: Path):
    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
    base_dir = tmp_path / "test"

    lance.write_dataset(table, base_dir)

    fragment = lance.dataset(base_dir).get_fragments()[0]
    merged = fragment.add_columns(
        lambda _: pa.RecordBatch.from_pydict({"c": range(100)})
    )

    expected = pa.Table.from_pydict({"a": range(100), "b": range(100), "c": range(100)})

    merge = lance.LanceOperation.Merge([merged], expected.schema)
    dataset = lance.LanceDataset._commit(base_dir, merge, read_version=1)

    tbl = dataset.to_table()
    assert tbl == expected


def test_data_files(tmp_path: Path):
    table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
    base_dir = tmp_path / "test"
@@ -411,7 +522,8 @@ def test_deletion_file(tmp_path: Path):
    assert new_fragment.deletion_file() is not None
    assert re.match("_deletions/0-1-[0-9]{1,32}.arrow", new_fragment.deletion_file())
    print(type(new_fragment), new_fragment)
    dataset = lance.LanceDataset._commit(base_dir, table.schema, [new_fragment])
    operation = lance.LanceOperation.Overwrite(table.schema, [new_fragment])
    dataset = lance.LanceDataset._commit(base_dir, operation)
    assert dataset.count_rows() == 90


@@ -429,7 +541,8 @@ def test_commit_fragments_via_scanner(tmp_path: Path):
    unpickled = pickle.loads(pickled)
    assert fragment_metadata == unpickled

    dataset = lance.LanceDataset._commit(base_dir, table.schema, [fragment_metadata])
    operation = lance.LanceOperation.Overwrite(table.schema, [fragment_metadata])
    dataset = lance.LanceDataset._commit(base_dir, operation)
    assert dataset.schema == table.schema

    tbl = dataset.to_table()
6 changes: 4 additions & 2 deletions python/python/tests/test_fragment.py
@@ -20,7 +20,7 @@
import pandas as pd
import pyarrow as pa
import pytest
from lance import FragmentMetadata, LanceDataset, LanceFragment
from lance import FragmentMetadata, LanceDataset, LanceFragment, LanceOperation
from lance.progress import FileSystemFragmentWriteProgress, FragmentWriteProgress


@@ -50,7 +50,9 @@ def test_write_fragment_two_phases(tmp_path: Path):
    fragments = [FragmentMetadata.from_json(j) for j in json_array]

    schema = pa.schema([pa.field("a", pa.int64())])
    dataset = LanceDataset._commit(tmp_path, schema, fragments)

    operation = LanceOperation.Overwrite(schema, fragments)
    dataset = LanceDataset._commit(tmp_path, operation)

    df = dataset.to_table().to_pandas()
    pd.testing.assert_frame_equal(
