Skip to content

Commit

Permalink
[Python] Expose projection for append column (#325)
Browse files Browse the repository at this point in the history
Pass columns from Python and add tests
  • Loading branch information
eddyxu authored Nov 22, 2022
1 parent 1de9d89 commit c4adcc3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
17 changes: 13 additions & 4 deletions python/lance/_lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ cdef extern from "lance/arrow/updater.h" namespace "lance::arrow" nogil:
CResult[shared_ptr[CLanceDataset]] Finish();

cdef cppclass CUpdaterBuilder "::lance::arrow::UpdaterBuilder":
void Project(vector[string] columns);

CResult[shared_ptr[CUpdater]] Finish();


Expand Down Expand Up @@ -283,9 +285,10 @@ cdef class FileSystemDataset(Dataset):
return FileSystemDataset.wrap(static_pointer_cast[CDataset, CLanceDataset](dataset))

def append_column(
self,
field: Field,
value: Union[Callable[[pyarrow.Table], pyarrow.Array], Expression],
self,
field: Field,
value: Union[Callable[[pyarrow.Table], pyarrow.Array], Expression],
columns: Optional[List[str]] = None,
) -> FileSystemDataset:
"""Append a new column.

Expand All @@ -296,16 +299,22 @@ cdef class FileSystemDataset(Dataset):
value : Callable[[pyarrow.Table], pyarrow.Array], pyarrow.compute.Expression
A function / callback that takes in a Batch and produces an Array. The generated array must
have the same length as the input batch.
columns : list of strs, optional.
The list of columns to read from the source dataset.
"""
cdef:
shared_ptr[CUpdater] c_updater
shared_ptr[CField] c_field
shared_ptr[CUpdaterBuilder] c_update_builder

if isinstance(value, Expression):
return self._append_column_expr(field, value)
elif isinstance(value, Callable):
c_field = pyarrow_unwrap_field(field)
c_updater = move(GetResultValue(GetResultValue(move(self.lance_dataset.NewUpdate(c_field))).get().Finish()))
c_update_builder = GetResultValue(self.lance_dataset.NewUpdate(c_field))
if columns is not None and len(columns) > 0:
c_update_builder.get().Project([tobytes(col) for col in columns])
c_updater = GetResultValue(c_update_builder.get().Finish())
updater = Updater.wrap(c_updater)
for table in updater:
arr = value(table)
Expand Down
19 changes: 19 additions & 0 deletions python/lance/tests/test_schema_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,25 @@ def test_write_versioned_dataset(tmp_path: Path):
pd.testing.assert_frame_equal(expected_df, actual_df)


def test_column_projection(tmp_path: Path):
    """Verify that ``append_column`` honours the ``columns`` projection.

    A two-column dataset ("a", "b") is written to disk, then a string
    column "c" is appended through a callback while projecting only "a".
    The callback itself asserts that it receives nothing but column "a",
    and the resulting dataset is compared against the expected frame.
    """
    dataset_uri = tmp_path / "test"
    source = pa.Table.from_pylist([{"a": 1, "b": 2}, {"a": 10, "b": 20}])
    lance.write_dataset(source, dataset_uri)

    dataset = lance.dataset(dataset_uri)

    def value_func(table: pa.Table):
        # Only the projected column may be visible to the callback.
        assert table.num_columns == 1
        assert table.column_names == ["a"]
        # Derive the new column by stringifying each value of "a".
        as_text = [str(scalar) for scalar in table.column("a")]
        return pa.array(as_text)

    new_field = pa.field("c", pa.utf8())
    new_dataset = dataset.append_column(new_field, value_func, columns=["a"])

    # The unprojected column "b" must still be present in the result.
    expected_df = pd.DataFrame({"a": [1, 10], "b": [2, 20], "c": ["1", "10"]})
    actual_df = new_dataset.to_table().to_pandas()
    pd.testing.assert_frame_equal(expected_df, actual_df)


def test_add_column_with_literal(tmp_path: Path):
table = pa.Table.from_pylist([{"a": i} for i in range(10)])

Expand Down

0 comments on commit c4adcc3

Please sign in to comment.