From a98356223699788a71c06d9202369fee476ecd87 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 29 Aug 2022 22:15:56 -0700 Subject: [PATCH 01/12] add pytorch dataset --- python/lance/pytorch/__init__.py | 14 +++++++ python/lance/pytorch/data.py | 65 ++++++++++++++++++++++++++++++ python/lance/tests/test_pytorch.py | 41 +++++++++++++++++++ python/setup.py | 5 +-- 4 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 python/lance/pytorch/__init__.py create mode 100644 python/lance/pytorch/data.py create mode 100644 python/lance/tests/test_pytorch.py diff --git a/python/lance/pytorch/__init__.py b/python/lance/pytorch/__init__.py new file mode 100644 index 0000000000..fe547f5cb4 --- /dev/null +++ b/python/lance/pytorch/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 Lance Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/python/lance/pytorch/data.py b/python/lance/pytorch/data.py new file mode 100644 index 0000000000..9766cea729 --- /dev/null +++ b/python/lance/pytorch/data.py @@ -0,0 +1,65 @@ +# Copyright 2022 Lance Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Union, Optional + +import pyarrow as pa +import pyarrow.dataset + +try: + from torch.utils.data import IterableDataset +except ImportError: + raise ImportError( + "Please install pytorch via pip install lance[pytorch]" + ) + +from lance import dataset + +__all__ = ["LanceDataset"] + + +class LanceDataset(IterableDataset): + """An PyTorch IterableDataset. + + See: + https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset + """ + + def __init__( + self, + uri: Union[str, Path], + columns: Optional[list[str]] = None, + batch_size: int = 64, + ): + self.uri = uri + self.columns = columns if columns else [] + self.batch_size = batch_size + self.scanner: pa.dataset.Scanner = dataset(self.uri).scanner( + columns=columns, batch_size=batch_size + ) + + def __repr__(self): + return f"LanceDataset(uri={self.uri})" + + def __iter__(self): + """Yield dataset""" + import torch + + for batch in self.scanner.scan_batches(): + # TODO: arrow.to_numpy(writable=True) makes a new copy of data. + yield [ + torch.from_numpy(arr.to_numpy(zero_copy_only=False, writable=True)) + for arr in batch.record_batch.columns + ] diff --git a/python/lance/tests/test_pytorch.py b/python/lance/tests/test_pytorch.py new file mode 100644 index 0000000000..e312547ca7 --- /dev/null +++ b/python/lance/tests/test_pytorch.py @@ -0,0 +1,41 @@ +# Copyright 2022 Lance Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import lance +import pytest + +pytest.importorskip("torch") + +import torch +import pyarrow as pa +from pathlib import Path +import pandas as pd +from torch.utils.data import DataLoader + +from lance.pytorch.data import LanceDataset + + +def test_data_loader(tmp_path: Path): + torch.Tensor([1, 2, 3]) + ids = pa.array(range(10)) + values = pa.array(range(10, 20)) + tab = pa.Table.from_arrays([ids, values], names=["id", "value"]) + print(tab) + print(tab.take([0])) + print(ids) + + lance.write_table(tab, tmp_path / "lance") + + dataset = LanceDataset(tmp_path / "lance", batch_size=4) + print(dataset) + print(next(iter(dataset))) diff --git a/python/setup.py b/python/setup.py index 6821fc8ed8..f445114eb9 100644 --- a/python/setup.py +++ b/python/setup.py @@ -32,7 +32,6 @@ arrow_library_dirs = pa.get_library_dirs() numpy_includes = np.get_include() - # TODO allow for custom liblance directory lance_cpp = Path(__file__).resolve().parent.parent / "cpp" lance_includes = str(lance_cpp / "include") @@ -51,7 +50,6 @@ ) ] - this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() @@ -69,7 +67,8 @@ ext_modules=cythonize(extensions, language_level="3"), zip_safe=False, install_requires=["pyarrow>=9,<10"], - extras_require={"test": ["pytest>=6.0", "pandas", "duckdb", "click"]}, + extras_require={"test": ["pytest>=6.0", "pandas", "duckdb", "click"], + "pytorch": ["pytorch"]}, python_requires=">=3.8", packages=find_packages(), classifiers=[ From 43f6a2759d93d8a0cd8a090600c75a2e7cfd1bde Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 29 Aug 2022 22:18:26 -0700 Subject: [PATCH 02/12] add pytorch dataset --- python/lance/tests/test_pytorch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/lance/tests/test_pytorch.py b/python/lance/tests/test_pytorch.py index e312547ca7..02dea95469 100644 --- a/python/lance/tests/test_pytorch.py +++ b/python/lance/tests/test_pytorch.py @@ -38,4 +38,7 @@ def test_data_loader(tmp_path: Path): dataset = LanceDataset(tmp_path / "lance", batch_size=4) print(dataset) - print(next(iter(dataset))) + id_batch, value_batch = next(iter(dataset)) + assert(id_batch.shape == (1, 4)) + + From a100680dae572fe818712e859f4f2b19ed3d9e8e Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 30 Aug 2022 10:55:35 -0700 Subject: [PATCH 03/12] pass batch size via scanner --- python/lance/__init__.py | 9 ++++++++- python/lance/_lib.pyx | 5 ++++- python/lance/pytorch/__init__.py | 1 - python/lance/pytorch/data.py | 13 ++++++------- python/lance/tests/test_api.py | 8 +------- python/lance/tests/test_pytorch.py | 4 +--- 6 files changed, 20 insertions(+), 20 deletions(-) diff --git a/python/lance/__init__.py b/python/lance/__init__.py index 5ace46d79b..2b7b4072c1 100644 --- a/python/lance/__init__.py +++ b/python/lance/__init__.py @@ -49,13 +49,20 @@ def scanner( data: Union[str, Path, ds.Dataset], columns: Optional[str] = None, filter: Optional[pc.Expression] = None, + batch_size: Optional[int] = None, limit: Optional[int] = None, offset: int = 0, ) -> ds.Scanner: if isinstance(data, (str, Path)): data = dataset(str(data)) + print(batch_size) return BuildScanner( - data, columns=columns, filter=filter, limit=limit, offset=offset + data, + columns=columns, + filter=filter, + batch_size=batch_size, + limit=limit, + offset=offset, ) diff --git a/python/lance/_lib.pyx b/python/lance/_lib.pyx index 750307e69e..a01b2bc2c3 100644 --- a/python/lance/_lib.pyx +++ b/python/lance/_lib.pyx @@ -77,6 +77,7 @@ cdef extern from "lance/arrow/scanner.h" namespace "lance::arrow" nogil: LScannerBuilder(shared_ptr[CDataset]) except + void Project(const vector[string]& columns) void Filter(CExpression filter) + CStatus BatchSize(int64_t batch_size) void Limit(int64_t limit, int64_t offset) CResult[shared_ptr[CScanner]] Finish() @@ -84,6 +85,7 @@ def BuildScanner( dataset: Dataset, columns: Optional[list[str]] = None, filter: Optional[Expression] = None, + batch_size: Optional[int] = None, limit: Optional[int] = None, offset: int = 0, ): @@ -95,9 +97,10 @@ def BuildScanner( builder.get().Project([tobytes(c) for c in columns]) if filter is not None: builder.get().Filter(_bind(filter, dataset.schema)) + if batch_size is not None: + builder.get().BatchSize(batch_size) if limit is not None: builder.get().Limit(limit, offset) - scanner = GetResultValue(builder.get().Finish()) creader = GetResultValue(scanner.get().ToRecordBatchReader()) reader = RecordBatchReader() diff --git a/python/lance/pytorch/__init__.py b/python/lance/pytorch/__init__.py index fe547f5cb4..37ce02c89e 100644 --- a/python/lance/pytorch/__init__.py +++ b/python/lance/pytorch/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/python/lance/pytorch/data.py b/python/lance/pytorch/data.py index 9766cea729..cbb228f660 100644 --- a/python/lance/pytorch/data.py +++ b/python/lance/pytorch/data.py @@ -21,9 +21,7 @@ try: from torch.utils.data import IterableDataset except ImportError: - raise ImportError( - "Please install pytorch via pip install lance[pytorch]" - ) + raise ImportError("Please install pytorch via pip install lance[pytorch]") from lance import dataset @@ -38,14 +36,15 @@ class LanceDataset(IterableDataset): """ def __init__( - self, - uri: Union[str, Path], - columns: Optional[list[str]] = None, - batch_size: int = 64, + self, + uri: Union[str, Path], + columns: Optional[list[str]] = None, + batch_size: Optional[int] = None, ): self.uri = uri self.columns = columns if columns else [] self.batch_size = batch_size + print(f"Init lance dataset: batch size={batch_size}") self.scanner: pa.dataset.Scanner = dataset(self.uri).scanner( columns=columns, batch_size=batch_size ) diff --git a/python/lance/tests/test_api.py b/python/lance/tests/test_api.py index 47aaa65745..c3b745118b 100644 --- a/python/lance/tests/test_api.py +++ b/python/lance/tests/test_api.py @@ -48,7 +48,6 @@ def test_write_categorical_values(tmp_path: Path): assert table == actual - def test_write_dataset(tmp_path: Path): table = pa.Table.from_pandas( pd.DataFrame( @@ -59,12 +58,7 @@ def test_write_dataset(tmp_path: Path): } ) ) - ds.write_dataset( - table, - tmp_path, - partitioning=["split"], - format=LanceFileFormat() - ) + ds.write_dataset(table, tmp_path, partitioning=["split"], format=LanceFileFormat()) part_dirs = [d.name for d in tmp_path.iterdir()] assert set(part_dirs) == set(["a", "b"]) diff --git a/python/lance/tests/test_pytorch.py b/python/lance/tests/test_pytorch.py index 02dea95469..22bf2ff2f0 100644 --- a/python/lance/tests/test_pytorch.py +++ b/python/lance/tests/test_pytorch.py @@ -39,6 +39,4 @@ def test_data_loader(tmp_path: Path): dataset = LanceDataset(tmp_path / "lance", batch_size=4) print(dataset) id_batch, value_batch = next(iter(dataset)) - assert(id_batch.shape == (1, 4)) - - + assert id_batch.shape == (1, 4) From 0f6ae9ef7c1cdbff62464d82c504607f298c5243 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 30 Aug 2022 13:10:32 -0700 Subject: [PATCH 04/12] pritn debug --- python/lance/pytorch/data.py | 6 +++--- python/lance/tests/test_pytorch.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/lance/pytorch/data.py b/python/lance/pytorch/data.py index cbb228f660..f8aa7075f9 100644 --- a/python/lance/pytorch/data.py +++ b/python/lance/pytorch/data.py @@ -23,7 +23,7 @@ except ImportError: raise ImportError("Please install pytorch via pip install lance[pytorch]") -from lance import dataset +from lance import dataset, scanner __all__ = ["LanceDataset"] @@ -45,8 +45,8 @@ def __init__( self.columns = columns if columns else [] self.batch_size = batch_size print(f"Init lance dataset: batch size={batch_size}") - self.scanner: pa.dataset.Scanner = dataset(self.uri).scanner( - columns=columns, batch_size=batch_size + self.scanner: pa.dataset.Scanner = scanner( + dataset(self.uri), columns=columns, batch_size=batch_size ) def __repr__(self): diff --git a/python/lance/tests/test_pytorch.py b/python/lance/tests/test_pytorch.py index 22bf2ff2f0..96490ef8ca 100644 --- a/python/lance/tests/test_pytorch.py +++ b/python/lance/tests/test_pytorch.py @@ -39,4 +39,4 @@ def test_data_loader(tmp_path: Path): dataset = LanceDataset(tmp_path / "lance", batch_size=4) print(dataset) id_batch, value_batch = next(iter(dataset)) - assert id_batch.shape == (1, 4) + assert id_batch.shape == 4 From d4f8f6fbce6fa092e7dbc0b76effc7e4c10fc065 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 30 Aug 2022 15:57:26 -0700 Subject: [PATCH 05/12] change to use record reader --- python/lance/pytorch/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lance/pytorch/data.py b/python/lance/pytorch/data.py index f8aa7075f9..8bda570f2c 100644 --- a/python/lance/pytorch/data.py +++ b/python/lance/pytorch/data.py @@ -56,9 +56,9 @@ def __iter__(self): """Yield dataset""" import torch - for batch in self.scanner.scan_batches(): + for batch in self.scanner.to_reader(): # TODO: arrow.to_numpy(writable=True) makes a new copy of data. yield [ torch.from_numpy(arr.to_numpy(zero_copy_only=False, writable=True)) - for arr in batch.record_batch.columns + for arr in batch.columns ] From d1967813dd55bd13b1406b584bfd1ed27d821c9d Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 5 Sep 2022 11:55:53 -0700 Subject: [PATCH 06/12] fix tests --- python/lance/_lib.pyx | 6 +++--- python/lance/pytorch/data.py | 1 - python/lance/tests/test_pytorch.py | 8 +------- python/setup.py | 2 +- 4 files changed, 5 insertions(+), 12 deletions(-) diff --git a/python/lance/_lib.pyx b/python/lance/_lib.pyx index a01b2bc2c3..a326f13edc 100644 --- a/python/lance/_lib.pyx +++ b/python/lance/_lib.pyx @@ -75,10 +75,10 @@ cdef extern from "lance/arrow/writer.h" namespace "lance::arrow" nogil: cdef extern from "lance/arrow/scanner.h" namespace "lance::arrow" nogil: cdef cppclass LScannerBuilder "::lance::arrow::ScannerBuilder": LScannerBuilder(shared_ptr[CDataset]) except + - void Project(const vector[string]& columns) - void Filter(CExpression filter) + CStatus Project(const vector[string]& columns) + CStatus Filter(CExpression filter) CStatus BatchSize(int64_t batch_size) - void Limit(int64_t limit, int64_t offset) + CStatus Limit(int64_t limit, int64_t offset) CResult[shared_ptr[CScanner]] Finish() def BuildScanner( diff --git a/python/lance/pytorch/data.py b/python/lance/pytorch/data.py index 8bda570f2c..1bc3d0385a 100644 --- a/python/lance/pytorch/data.py +++ b/python/lance/pytorch/data.py @@ -44,7 +44,6 @@ def __init__( self.uri = uri self.columns = columns if columns else [] self.batch_size = batch_size - print(f"Init lance dataset: batch size={batch_size}") self.scanner: pa.dataset.Scanner = scanner( dataset(self.uri), columns=columns, batch_size=batch_size ) diff --git a/python/lance/tests/test_pytorch.py b/python/lance/tests/test_pytorch.py index 96490ef8ca..0aa2c0ffad 100644 --- a/python/lance/tests/test_pytorch.py +++ b/python/lance/tests/test_pytorch.py @@ -19,8 +19,6 @@ import torch import pyarrow as pa from pathlib import Path -import pandas as pd -from torch.utils.data import DataLoader from lance.pytorch.data import LanceDataset @@ -30,13 +28,9 @@ def test_data_loader(tmp_path: Path): ids = pa.array(range(10)) values = pa.array(range(10, 20)) tab = pa.Table.from_arrays([ids, values], names=["id", "value"]) - print(tab) - print(tab.take([0])) - print(ids) lance.write_table(tab, tmp_path / "lance") dataset = LanceDataset(tmp_path / "lance", batch_size=4) - print(dataset) id_batch, value_batch = next(iter(dataset)) - assert id_batch.shape == 4 + assert id_batch.shape == torch.Size([4]) diff --git a/python/setup.py b/python/setup.py index f445114eb9..e40e1a5c21 100644 --- a/python/setup.py +++ b/python/setup.py @@ -68,7 +68,7 @@ zip_safe=False, install_requires=["pyarrow>=9,<10"], extras_require={"test": ["pytest>=6.0", "pandas", "duckdb", "click"], - "pytorch": ["pytorch"]}, + "pytorch": ["torch"]}, python_requires=">=3.8", packages=find_packages(), classifiers=[ From 951d9a783e3b79d20cbc97bcdbb4484b70e0c17d Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 5 Sep 2022 11:56:55 -0700 Subject: [PATCH 07/12] install pytorch on github action --- .github/workflows/python.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index b6bf8fe4e9..7e1222190a 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -31,7 +31,7 @@ jobs: run: ./tools/build_wheel.sh $(echo cp${{ matrix.python-version }} | sed "s/\.//") - name: Pip install run: | - python -m pip install $(ls wheels/*.whl)[test] + python -m pip install $(ls wheels/*.whl)[test,pytorch] - name: Run python tests run: | pytest -x -v --durations=10 From bf5eb5bb00b35a54b925c375989041fd5285ed3a Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 5 Sep 2022 13:10:51 -0700 Subject: [PATCH 08/12] fix py38 compatiblity --- python/lance/__init__.py | 1 - python/lance/_lib.pyx | 13 +++++++++++-- python/lance/pytorch/data.py | 4 ++-- python/lance/tests/test_api.py | 4 ++-- python/lance/tests/test_pytorch.py | 8 +++++--- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/python/lance/__init__.py b/python/lance/__init__.py index 2b7b4072c1..b85c1e0520 100644 --- a/python/lance/__init__.py +++ b/python/lance/__init__.py @@ -55,7 +55,6 @@ def scanner( ) -> ds.Scanner: if isinstance(data, (str, Path)): data = dataset(str(data)) - print(batch_size) return BuildScanner( data, columns=columns, diff --git a/python/lance/_lib.pyx b/python/lance/_lib.pyx index a326f13edc..d376d1315c 100644 --- a/python/lance/_lib.pyx +++ b/python/lance/_lib.pyx @@ -4,13 +4,22 @@ from typing import Optional, Union from cython.operator cimport dereference as deref from libcpp cimport bool -from libcpp.memory cimport shared_ptr, const_pointer_cast +from libcpp.memory cimport const_pointer_cast, shared_ptr from libcpp.string cimport string from pathlib import Path from pyarrow import Table -from pyarrow._dataset cimport FileFormat, FileWriteOptions, CFileWriteOptions, CScanner, CDataset, Dataset + +from pyarrow._dataset cimport ( + CDataset, + CFileWriteOptions, + CScanner, + Dataset, + FileFormat, + FileWriteOptions, +) + from pyarrow._dataset import Scanner from pyarrow._compute cimport Expression, _bind diff --git a/python/lance/pytorch/data.py b/python/lance/pytorch/data.py index 1bc3d0385a..57fd5f040f 100644 --- a/python/lance/pytorch/data.py +++ b/python/lance/pytorch/data.py @@ -13,7 +13,7 @@ # limitations under the License. from pathlib import Path -from typing import Union, Optional +from typing import List, Optional, Union import pyarrow as pa import pyarrow.dataset @@ -38,7 +38,7 @@ class LanceDataset(IterableDataset): def __init__( self, uri: Union[str, Path], - columns: Optional[list[str]] = None, + columns: Optional[List[str]] = None, batch_size: Optional[int] = None, ): self.uri = uri diff --git a/python/lance/tests/test_api.py b/python/lance/tests/test_api.py index c3b745118b..1a4cc10db2 100644 --- a/python/lance/tests/test_api.py +++ b/python/lance/tests/test_api.py @@ -16,10 +16,10 @@ from pathlib import Path import pandas as pd - import pyarrow as pa import pyarrow.dataset as ds -from lance import write_table, dataset, LanceFileFormat + +from lance import LanceFileFormat, dataset, write_table def test_simple_round_trips(tmp_path: Path): diff --git a/python/lance/tests/test_pytorch.py b/python/lance/tests/test_pytorch.py index 0aa2c0ffad..d181811572 100644 --- a/python/lance/tests/test_pytorch.py +++ b/python/lance/tests/test_pytorch.py @@ -11,15 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import lance + import pytest pytest.importorskip("torch") -import torch -import pyarrow as pa from pathlib import Path +import pyarrow as pa +import torch + +import lance from lance.pytorch.data import LanceDataset From 2c248cc807a0ca3f133b5baf3fad222392685b45 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 5 Sep 2022 13:17:28 -0700 Subject: [PATCH 09/12] cleanup --- python/lance/pytorch/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lance/pytorch/data.py b/python/lance/pytorch/data.py index 57fd5f040f..422c73244f 100644 --- a/python/lance/pytorch/data.py +++ b/python/lance/pytorch/data.py @@ -19,6 +19,7 @@ import pyarrow.dataset try: + import torch from torch.utils.data import IterableDataset except ImportError: raise ImportError("Please install pytorch via pip install lance[pytorch]") @@ -53,10 +54,9 @@ def __repr__(self): def __iter__(self): """Yield dataset""" - import torch - for batch in self.scanner.to_reader(): # TODO: arrow.to_numpy(writable=True) makes a new copy of data. + # Investigate how to directly perform zero-copy into Torch Tensor. yield [ torch.from_numpy(arr.to_numpy(zero_copy_only=False, writable=True)) for arr in batch.columns From d78b50ca7bb2c95504eda056f004f49b1d76fa8e Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 5 Sep 2022 22:01:41 -0700 Subject: [PATCH 10/12] address comment --- python/lance/__init__.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/python/lance/__init__.py b/python/lance/__init__.py index b85c1e0520..bb3984d368 100644 --- a/python/lance/__init__.py +++ b/python/lance/__init__.py @@ -14,7 +14,7 @@ import platform from pathlib import Path -from typing import Optional, Union +from typing import List, Optional, Union import pyarrow as pa import pyarrow.compute as pc @@ -47,12 +47,35 @@ def dataset( def scanner( data: Union[str, Path, ds.Dataset], - columns: Optional[str] = None, + columns: Optional[List[str]] = None, filter: Optional[pc.Expression] = None, batch_size: Optional[int] = None, limit: Optional[int] = None, offset: int = 0, ) -> ds.Scanner: + """Build a PyArrow Dataset scanner. + + It extends PyArrow Scanner with limit pushdown. + + Parameters + ---------- + data: uri, path or pyarrow dataset + The Dataset + columns: List[str], optional + Specify the columns to read. + filter: pc.Expression, optional + Apply filter to the scanner. + batch_size: int + The maximum number of records to scan for each batch. + limit: int + Limit the number of records to return in total. + offset: int + The offset to read the data from. + + See Also + -------- + https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner + """ if isinstance(data, (str, Path)): data = dataset(str(data)) return BuildScanner( From 5ebdb62c828a6ec574204211ad102cde7f66bcc6 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 5 Sep 2022 22:21:36 -0700 Subject: [PATCH 11/12] address comments --- python/lance/pytorch/data.py | 4 ++-- python/lance/tests/test_pytorch.py | 3 +-- python/setup.py | 7 +++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/python/lance/pytorch/data.py b/python/lance/pytorch/data.py index 422c73244f..796c11dfbb 100644 --- a/python/lance/pytorch/data.py +++ b/python/lance/pytorch/data.py @@ -21,8 +21,8 @@ try: import torch from torch.utils.data import IterableDataset -except ImportError: - raise ImportError("Please install pytorch via pip install lance[pytorch]") +except ImportError as e: + raise ImportError("Please install pytorch", e) from lance import dataset, scanner diff --git a/python/lance/tests/test_pytorch.py b/python/lance/tests/test_pytorch.py index d181811572..de030f6a3a 100644 --- a/python/lance/tests/test_pytorch.py +++ b/python/lance/tests/test_pytorch.py @@ -14,12 +14,11 @@ import pytest -pytest.importorskip("torch") +torch = pytest.importorskip("torch") from pathlib import Path import pyarrow as pa -import torch import lance from lance.pytorch.data import LanceDataset diff --git a/python/setup.py b/python/setup.py index e40e1a5c21..6574c16d79 100644 --- a/python/setup.py +++ b/python/setup.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path -from setuptools import Extension, find_packages, setup import platform +from pathlib import Path import numpy as np import pyarrow as pa from Cython.Build import cythonize +from setuptools import Extension, find_packages, setup extra_libs = [] # TODO: ciwheelbuild can not find / dont need arrow_python. @@ -67,8 +67,7 @@ ext_modules=cythonize(extensions, language_level="3"), zip_safe=False, install_requires=["pyarrow>=9,<10"], - extras_require={"test": ["pytest>=6.0", "pandas", "duckdb", "click"], - "pytorch": ["torch"]}, + extras_require={"test": ["pytest>=6.0", "pandas", "duckdb", "click"]}, python_requires=">=3.8", packages=find_packages(), classifiers=[ From 8ace2a9cd696762ffaf65baf047c27c5d1549f1e Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Wed, 7 Sep 2022 10:23:44 -0700 Subject: [PATCH 12/12] install pytorch manually in GHA --- .github/workflows/python.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 7e1222190a..50c20ce2d7 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -31,7 +31,8 @@ jobs: run: ./tools/build_wheel.sh $(echo cp${{ matrix.python-version }} | sed "s/\.//") - name: Pip install run: | - python -m pip install $(ls wheels/*.whl)[test,pytorch] + python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu + python -m pip install $(ls wheels/*.whl)[test] - name: Run python tests run: | pytest -x -v --durations=10