diff --git a/python/lance/_lib.pyx b/python/lance/_lib.pyx index 84d7507c48..3a0cd09be4 100644 --- a/python/lance/_lib.pyx +++ b/python/lance/_lib.pyx @@ -155,6 +155,8 @@ cdef extern from "lance/arrow/updater.h" namespace "lance::arrow" nogil: CResult[shared_ptr[CLanceDataset]] Finish(); cdef cppclass CUpdaterBuilder "::lance::arrow::UpdaterBuilder": + void Project(vector[string] columns); + CResult[shared_ptr[CUpdater]] Finish(); @@ -283,9 +285,10 @@ cdef class FileSystemDataset(Dataset): return FileSystemDataset.wrap(static_pointer_cast[CDataset, CLanceDataset](dataset)) def append_column( - self, - field: Field, - value: Union[Callable[[pyarrow.Table], pyarrow.Array], Expression], + self, + field: Field, + value: Union[Callable[[pyarrow.Table], pyarrow.Array], Expression], + columns: Optional[List[str]] = None, ) -> FileSystemDataset: """Append a new column. @@ -296,16 +299,22 @@ cdef class FileSystemDataset(Dataset): value : Callback[[pyarrow.Table], pyarrow.Array], pyarrow.compute.Expression A function / callback that takes in a Batch and produces an Array. The generated array must have the same length as the input batch. + columns : list of strs, optional. + The list of columns to read from the source dataset. """ cdef: shared_ptr[CUpdater] c_updater shared_ptr[CField] c_field + shared_ptr[CUpdaterBuilder] c_update_builder if isinstance(value, Expression): return self._append_column_expr(field, value) elif isinstance(value, Callable): c_field = pyarrow_unwrap_field(field) - c_updater = move(GetResultValue(GetResultValue(move(self.lance_dataset.NewUpdate(c_field))).get().Finish())) + c_update_builder = GetResultValue(self.lance_dataset.NewUpdate(c_field)) + if columns is not None and len(columns) > 0: + c_update_builder.get().Project([tobytes(col) for col in columns]) + c_updater = GetResultValue(c_update_builder.get().Finish()) updater = Updater.wrap(c_updater) for table in updater: arr = value(table) diff --git a/python/lance/tests/test_schema_evolution.py b/python/lance/tests/test_schema_evolution.py index bbab7fb7f6..0e806f6a89 100644 --- a/python/lance/tests/test_schema_evolution.py +++ b/python/lance/tests/test_schema_evolution.py @@ -34,6 +34,25 @@ def test_write_versioned_dataset(tmp_path: Path): pd.testing.assert_frame_equal(expected_df, actual_df) +def test_column_projection(tmp_path: Path): + table1 = pa.Table.from_pylist([{"a": 1, "b": 2}, {"a": 10, "b": 20}]) + base_dir = tmp_path / "test" + lance.write_dataset(table1, base_dir) + + dataset = lance.dataset(base_dir) + + def value_func(x: pa.Table): + assert x.num_columns == 1 + assert x.column_names == ["a"] + return pa.array([str(i) for i in x.column("a")]) + + new_dataset = dataset.append_column(pa.field("c", pa.utf8()), value_func, columns=["a"]) + + actual_df = new_dataset.to_table().to_pandas() + expected_df = pd.DataFrame({"a": [1, 10], "b": [2, 20], "c": ["1", "10"]}) + pd.testing.assert_frame_equal(expected_df, actual_df) + + def test_add_column_with_literal(tmp_path: Path): table = pa.Table.from_pylist([{"a": i} for i in range(10)])