Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial PyTorch Dataset support #134

Merged
merged 12 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
run: ./tools/build_wheel.sh $(echo cp${{ matrix.python-version }} | sed "s/\.//")
- name: Pip install
run: |
python -m pip install $(ls wheels/*.whl)[test]
python -m pip install $(ls wheels/*.whl)[test,pytorch]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

install pytorch directly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

- name: Run python tests
run: |
pytest -x -v --durations=10
Expand Down
35 changes: 32 additions & 3 deletions python/lance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import platform
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union

import pyarrow as pa
import pyarrow.compute as pc
Expand Down Expand Up @@ -47,15 +47,44 @@ def dataset(

def scanner(
data: Union[str, Path, ds.Dataset],
columns: Optional[str] = None,
columns: Optional[List[str]] = None,
filter: Optional[pc.Expression] = None,
batch_size: Optional[int] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add docstrings for the function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

limit: Optional[int] = None,
offset: int = 0,
) -> ds.Scanner:
"""Build a PyArrow Dataset scanner.

It extends PyArrow Scanner with limit pushdown.

Parameters
----------
data: uri, path or pyarrow dataset
The Dataset
columns: List[str], optional
Specify the columns to read.
filter: pc.Expression, optional
Apply filter to the scanner.
batch_size: int
The maximum number of records to scan for each batch.
limit: int
Limit the number of records to return in total.
offset: int
The offset to read the data from.

See Also
--------
https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner
"""
if isinstance(data, (str, Path)):
data = dataset(str(data))
return BuildScanner(
data, columns=columns, filter=filter, limit=limit, offset=offset
data,
columns=columns,
filter=filter,
batch_size=batch_size,
limit=limit,
offset=offset,
)


Expand Down
24 changes: 18 additions & 6 deletions python/lance/_lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@ from typing import Optional, Union

from cython.operator cimport dereference as deref
from libcpp cimport bool
from libcpp.memory cimport shared_ptr, const_pointer_cast
from libcpp.memory cimport const_pointer_cast, shared_ptr
from libcpp.string cimport string

from pathlib import Path

from pyarrow import Table
from pyarrow._dataset cimport FileFormat, FileWriteOptions, CFileWriteOptions, CScanner, CDataset, Dataset

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes were due to running isort

from pyarrow._dataset cimport (
CDataset,
CFileWriteOptions,
CScanner,
Dataset,
FileFormat,
FileWriteOptions,
)

from pyarrow._dataset import Scanner

from pyarrow._compute cimport Expression, _bind
Expand Down Expand Up @@ -75,15 +84,17 @@ cdef extern from "lance/arrow/writer.h" namespace "lance::arrow" nogil:
cdef extern from "lance/arrow/scanner.h" namespace "lance::arrow" nogil:
cdef cppclass LScannerBuilder "::lance::arrow::ScannerBuilder":
LScannerBuilder(shared_ptr[CDataset]) except +
void Project(const vector[string]& columns)
void Filter(CExpression filter)
void Limit(int64_t limit, int64_t offset)
CStatus Project(const vector[string]& columns)
CStatus Filter(CExpression filter)
CStatus BatchSize(int64_t batch_size)
CStatus Limit(int64_t limit, int64_t offset)
CResult[shared_ptr[CScanner]] Finish()

def BuildScanner(
dataset: Dataset,
columns: Optional[list[str]] = None,
filter: Optional[Expression] = None,
batch_size: Optional[int] = None,
limit: Optional[int] = None,
offset: int = 0,
):
Expand All @@ -95,9 +106,10 @@ def BuildScanner(
builder.get().Project([tobytes(c) for c in columns])
if filter is not None:
builder.get().Filter(_bind(filter, dataset.schema))
if batch_size is not None:
builder.get().BatchSize(batch_size)
if limit is not None:
builder.get().Limit(limit, offset)

scanner = GetResultValue(builder.get().Finish())
creader = GetResultValue(scanner.get().ToRecordBatchReader())
reader = RecordBatchReader()
Expand Down
13 changes: 13 additions & 0 deletions python/lance/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2022 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
63 changes: 63 additions & 0 deletions python/lance/pytorch/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2022 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import List, Optional, Union

import pyarrow as pa
import pyarrow.dataset

# PyTorch is an optional dependency: fail at import time with an
# actionable message instead of a bare "No module named 'torch'".
try:
    import torch
    from torch.utils.data import IterableDataset
except ImportError as err:
    # Chain the original error rather than passing it as a second
    # positional argument, so the message stays a plain string and the
    # root cause is preserved on ``__cause__`` in the traceback.
    raise ImportError("Please install pytorch") from err

from lance import dataset, scanner

__all__ = ["LanceDataset"]


class LanceDataset(IterableDataset):
    """PyTorch ``IterableDataset`` backed by a Lance dataset scanner.

    Iterating yields, for every record batch produced by the scanner, a
    list holding one ``torch.Tensor`` per column of that batch.

    See:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
    """

    def __init__(
        self,
        uri: Union[str, Path],
        columns: Optional[List[str]] = None,
        batch_size: Optional[int] = None,
    ):
        """Open the dataset at ``uri`` and build a scanner over it.

        Parameters
        ----------
        uri: str or Path
            Location passed to ``lance.dataset``.
        columns: List[str], optional
            Columns forwarded to the scanner. Stored on the instance as an
            empty list when not given.
        batch_size: int, optional
            Batch size forwarded to the scanner.
        """
        self.uri = uri
        self.columns = columns or []
        self.batch_size = batch_size
        # The scanner receives the caller's original ``columns`` value
        # (possibly None), not the normalized ``self.columns``.
        lance_dataset = dataset(self.uri)
        self.scanner: pa.dataset.Scanner = scanner(
            lance_dataset, columns=columns, batch_size=batch_size
        )

    def __repr__(self):
        return f"LanceDataset(uri={self.uri})"

    def __iter__(self):
        """Yield one list of tensors (one per column) for each record batch."""
        reader = self.scanner.to_reader()
        for record_batch in reader:
            # TODO: arrow.to_numpy(writable=True) makes a new copy of data.
            # Investigate how to directly perform zero-copy into Torch Tensor.
            tensors = []
            for column in record_batch.columns:
                ndarray = column.to_numpy(zero_copy_only=False, writable=True)
                tensors.append(torch.from_numpy(ndarray))
            yield tensors
12 changes: 3 additions & 9 deletions python/lance/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from pathlib import Path
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes in this whole file are just formatting, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that's right.


import pandas as pd

import pyarrow as pa
import pyarrow.dataset as ds
from lance import write_table, dataset, LanceFileFormat

from lance import LanceFileFormat, dataset, write_table


def test_simple_round_trips(tmp_path: Path):
Expand Down Expand Up @@ -48,7 +48,6 @@ def test_write_categorical_values(tmp_path: Path):
assert table == actual



def test_write_dataset(tmp_path: Path):
table = pa.Table.from_pandas(
pd.DataFrame(
Expand All @@ -59,12 +58,7 @@ def test_write_dataset(tmp_path: Path):
}
)
)
ds.write_dataset(
table,
tmp_path,
partitioning=["split"],
format=LanceFileFormat()
)
ds.write_dataset(table, tmp_path, partitioning=["split"], format=LanceFileFormat())

part_dirs = [d.name for d in tmp_path.iterdir()]
assert set(part_dirs) == set(["a", "b"])
Expand Down
37 changes: 37 additions & 0 deletions python/lance/tests/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2022 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

torch = pytest.importorskip("torch")

from pathlib import Path

import pyarrow as pa

import lance
from lance.pytorch.data import LanceDataset


def test_data_loader(tmp_path: Path):
    """Round-trip a two-column table through ``LanceDataset``.

    Writes ten rows (``id`` 0..9, ``value`` 10..19), reads them back with
    ``batch_size=4`` and checks the first yielded batch for both columns.
    """
    ids = pa.array(range(10))
    values = pa.array(range(10, 20))
    tab = pa.Table.from_arrays([ids, values], names=["id", "value"])

    lance.write_table(tab, tmp_path / "lance")

    dataset = LanceDataset(tmp_path / "lance", batch_size=4)
    id_batch, value_batch = next(iter(dataset))

    # Both columns must honour the requested batch size ...
    assert id_batch.shape == torch.Size([4])
    assert value_batch.shape == torch.Size([4])
    # ... and carry the first four rows that were written.
    assert id_batch.tolist() == [0, 1, 2, 3]
    assert value_batch.tolist() == [10, 11, 12, 13]
6 changes: 2 additions & 4 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from setuptools import Extension, find_packages, setup
import platform
from pathlib import Path

import numpy as np
import pyarrow as pa
from Cython.Build import cythonize
from setuptools import Extension, find_packages, setup

extra_libs = []
# TODO: ciwheelbuild cannot find / doesn't need arrow_python.
Expand All @@ -32,7 +32,6 @@
arrow_library_dirs = pa.get_library_dirs()
numpy_includes = np.get_include()


# TODO allow for custom liblance directory
lance_cpp = Path(__file__).resolve().parent.parent / "cpp"
lance_includes = str(lance_cpp / "include")
Expand All @@ -51,7 +50,6 @@
)
]


this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()

Expand Down