Initial PyTorch Dataset support (#134)
eddyxu authored Sep 7, 2022
1 parent 294434d commit 3f192d3
Showing 7 changed files with 153 additions and 9 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python.yml
@@ -31,6 +31,7 @@ jobs:
run: ./tools/build_wheel.sh $(echo cp${{ matrix.python-version }} | sed "s/\.//")
- name: Pip install
run: |
python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install $(ls wheels/*.whl)[test]
- name: Run python tests
run: |
35 changes: 32 additions & 3 deletions python/lance/__init__.py
@@ -14,7 +14,7 @@

import platform
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union

import pyarrow as pa
import pyarrow.compute as pc
@@ -47,15 +47,44 @@ def dataset(

def scanner(
data: Union[str, Path, ds.Dataset],
columns: Optional[str] = None,
columns: Optional[List[str]] = None,
filter: Optional[pc.Expression] = None,
batch_size: Optional[int] = None,
limit: Optional[int] = None,
offset: int = 0,
) -> ds.Scanner:
"""Build a PyArrow Dataset scanner.
It extends PyArrow Scanner with limit pushdown.
Parameters
----------
data: uri, path or pyarrow dataset
The Dataset
columns: List[str], optional
Specify the columns to read.
filter: pc.Expression, optional
Apply filter to the scanner.
batch_size: int
The maximum number of records to scan for each batch.
limit: int
Limit the number of records to return in total.
offset: int
The offset to read the data from.
See Also
--------
https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner
"""
if isinstance(data, (str, Path)):
data = dataset(str(data))
return BuildScanner(
data, columns=columns, filter=filter, limit=limit, offset=offset
data,
columns=columns,
filter=filter,
batch_size=batch_size,
limit=limit,
offset=offset,
)


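For reference, a minimal usage sketch of the extended scanner() signature (not part of the commit; the path, column names, and filter are illustrative assumptions):

import lance
import pyarrow.compute as pc

# Hypothetical example: the path and column names are assumptions.
scan = lance.scanner(
    "/tmp/example.lance",
    columns=["id", "value"],
    filter=pc.field("value") > 10,
    batch_size=1024,  # new in this commit; pushed down to the native ScannerBuilder
    limit=100,
    offset=20,
)
table = scan.to_table()  # a pyarrow.Table with at most 100 matching rows
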
11 changes: 7 additions & 4 deletions python/lance/_lib.pyx
@@ -84,15 +84,17 @@ cdef extern from "lance/arrow/writer.h" namespace "lance::arrow" nogil:
cdef extern from "lance/arrow/scanner.h" namespace "lance::arrow" nogil:
cdef cppclass LScannerBuilder "::lance::arrow::ScannerBuilder":
LScannerBuilder(shared_ptr[CDataset]) except +
void Project(const vector[string]& columns)
void Filter(CExpression filter)
void Limit(int64_t limit, int64_t offset)
CStatus Project(const vector[string]& columns)
CStatus Filter(CExpression filter)
CStatus BatchSize(int64_t batch_size)
CStatus Limit(int64_t limit, int64_t offset)
CResult[shared_ptr[CScanner]] Finish()

def BuildScanner(
dataset: Dataset,
columns: Optional[list[str]] = None,
filter: Optional[Expression] = None,
batch_size: Optional[int] = None,
limit: Optional[int] = None,
offset: int = 0,
):
@@ -104,9 +106,10 @@ def BuildScanner(
builder.get().Project([tobytes(c) for c in columns])
if filter is not None:
builder.get().Filter(_bind(filter, dataset.schema))
if batch_size is not None:
builder.get().BatchSize(batch_size)
if limit is not None:
builder.get().Limit(limit, offset)

scanner = GetResultValue(builder.get().Finish())
creader = GetResultValue(scanner.get().ToRecordBatchReader())
reader = RecordBatchReader()
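
As a rough illustration of the new batch-size pushdown (a sketch, not part of the commit; the path is an assumption):

import lance

reader = lance.scanner("/tmp/example.lance", batch_size=4).to_reader()
for batch in reader:
    # Per the scanner docstring, batch_size is the maximum number of records
    # per batch, so each RecordBatch should carry at most 4 rows.
    assert batch.num_rows <= 4
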
13 changes: 13 additions & 0 deletions python/lance/pytorch/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
63 changes: 63 additions & 0 deletions python/lance/pytorch/data.py
@@ -0,0 +1,63 @@
# Copyright 2022 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import List, Optional, Union

import pyarrow as pa
import pyarrow.dataset

try:
import torch
from torch.utils.data import IterableDataset
except ImportError as e:
raise ImportError("Please install pytorch", e)

from lance import dataset, scanner

__all__ = ["LanceDataset"]


class LanceDataset(IterableDataset):
"""An PyTorch IterableDataset.
See:
https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
"""

def __init__(
self,
uri: Union[str, Path],
columns: Optional[List[str]] = None,
batch_size: Optional[int] = None,
):
self.uri = uri
self.columns = columns if columns else []
self.batch_size = batch_size
self.scanner: pa.dataset.Scanner = scanner(
dataset(self.uri), columns=columns, batch_size=batch_size
)

def __repr__(self):
return f"LanceDataset(uri={self.uri})"

def __iter__(self):
"""Yield dataset"""
for batch in self.scanner.to_reader():
# TODO: arrow.to_numpy(writable=True) makes a new copy of data.
# Investigate how to directly perform zero-copy into Torch Tensor.
yield [
torch.from_numpy(arr.to_numpy(zero_copy_only=False, writable=True))
for arr in batch.columns
]
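
For context, a minimal sketch of how LanceDataset might be fed into a torch DataLoader (not part of the commit; the path and columns are assumptions):

from torch.utils.data import DataLoader

from lance.pytorch.data import LanceDataset

# Hypothetical usage: assumes a Lance file with "id" and "value" columns
# has already been written to /tmp/example.lance.
ds = LanceDataset("/tmp/example.lance", columns=["id", "value"], batch_size=256)
# The Lance scanner already yields whole batches, so tell the DataLoader
# not to re-batch them (batch_size=None).
loader = DataLoader(ds, batch_size=None)
for id_tensor, value_tensor in loader:
    pass  # one torch.Tensor per requested column, at most 256 rows each
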
37 changes: 37 additions & 0 deletions python/lance/tests/test_pytorch.py
@@ -0,0 +1,37 @@
# Copyright 2022 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

torch = pytest.importorskip("torch")

from pathlib import Path

import pyarrow as pa

import lance
from lance.pytorch.data import LanceDataset


def test_data_loader(tmp_path: Path):
torch.Tensor([1, 2, 3])
ids = pa.array(range(10))
values = pa.array(range(10, 20))
tab = pa.Table.from_arrays([ids, values], names=["id", "value"])

lance.write_table(tab, tmp_path / "lance")

dataset = LanceDataset(tmp_path / "lance", batch_size=4)
id_batch, value_batch = next(iter(dataset))
assert id_batch.shape == torch.Size([4])
2 changes: 0 additions & 2 deletions python/setup.py
@@ -32,7 +32,6 @@
arrow_library_dirs = pa.get_library_dirs()
numpy_includes = np.get_include()


# TODO allow for custom liblance directory
lance_cpp = Path(__file__).resolve().parent.parent / "cpp"
lance_includes = str(lance_cpp / "include")
@@ -51,7 +50,6 @@
)
]


this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()

