Initial PyTorch Dataset support #134
@@ -14,7 +14,7 @@
import platform
from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union

import pyarrow as pa
import pyarrow.compute as pc

@@ -47,15 +47,44 @@ def dataset(
def scanner(
    data: Union[str, Path, ds.Dataset],
-    columns: Optional[str] = None,
+    columns: Optional[List[str]] = None,
    filter: Optional[pc.Expression] = None,
+    batch_size: Optional[int] = None,

Review comment: Let's add docstrings for the function?
Reply: Done

    limit: Optional[int] = None,
    offset: int = 0,
) -> ds.Scanner:
+    """Build a PyArrow Dataset scanner.
+
+    It extends PyArrow Scanner with limit pushdown.
+
+    Parameters
+    ----------
+    data: uri, path or pyarrow dataset
+        The Dataset
+    columns: List[str], optional
+        Specify the columns to read.
+    filter: pc.Expression, optional
+        Apply filter to the scanner.
+    batch_size: int
+        The maximum number of records to scan for each batch.
+    limit: int
+        Limit the number of records to return in total.
+    offset: int
+        The offset to read the data from.
+
+    See Also
+    --------
+    https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner
+    """
    if isinstance(data, (str, Path)):
        data = dataset(str(data))
    return BuildScanner(
-        data, columns=columns, filter=filter, limit=limit, offset=offset
+        data,
+        columns=columns,
+        filter=filter,
+        batch_size=batch_size,
+        limit=limit,
+        offset=offset,
    )
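For context, a minimal usage sketch of the extended scanner() signature above. The dataset path, column names, and filter values are hypothetical; the read goes through to_reader(), which this PR itself uses.

import pyarrow.compute as pc
import lance

# Hypothetical Lance dataset path and columns, shown only to illustrate the
# batch_size / limit / offset parameters added in this PR.
scan = lance.scanner(
    "/tmp/example.lance",
    columns=["id", "value"],
    filter=pc.field("value") > 10,
    batch_size=1024,
    limit=100,
    offset=20,
)
table = scan.to_reader().read_all()  # a pyarrow.Table with at most 100 rows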
@@ -4,13 +4,22 @@ from typing import Optional, Union
from cython.operator cimport dereference as deref
from libcpp cimport bool
-from libcpp.memory cimport shared_ptr, const_pointer_cast
+from libcpp.memory cimport const_pointer_cast, shared_ptr
from libcpp.string cimport string

from pathlib import Path

from pyarrow import Table
-from pyarrow._dataset cimport FileFormat, FileWriteOptions, CFileWriteOptions, CScanner, CDataset, Dataset
+from pyarrow._dataset cimport (
+    CDataset,
+    CFileWriteOptions,
+    CScanner,
+    Dataset,
+    FileFormat,
+    FileWriteOptions,
+)

Review comment: These changes were due to running …

from pyarrow._dataset import Scanner

from pyarrow._compute cimport Expression, _bind

@@ -75,15 +84,17 @@ cdef extern from "lance/arrow/writer.h" namespace "lance::arrow" nogil:
cdef extern from "lance/arrow/scanner.h" namespace "lance::arrow" nogil:
    cdef cppclass LScannerBuilder "::lance::arrow::ScannerBuilder":
        LScannerBuilder(shared_ptr[CDataset]) except +
-        void Project(const vector[string]& columns)
-        void Filter(CExpression filter)
-        void Limit(int64_t limit, int64_t offset)
+        CStatus Project(const vector[string]& columns)
+        CStatus Filter(CExpression filter)
+        CStatus BatchSize(int64_t batch_size)
+        CStatus Limit(int64_t limit, int64_t offset)
        CResult[shared_ptr[CScanner]] Finish()

def BuildScanner(
    dataset: Dataset,
    columns: Optional[list[str]] = None,
    filter: Optional[Expression] = None,
+    batch_size: Optional[int] = None,
    limit: Optional[int] = None,
    offset: int = 0,
):

@@ -95,9 +106,10 @@ def BuildScanner(
        builder.get().Project([tobytes(c) for c in columns])
    if filter is not None:
        builder.get().Filter(_bind(filter, dataset.schema))
+    if batch_size is not None:
+        builder.get().BatchSize(batch_size)
    if limit is not None:
        builder.get().Limit(limit, offset)

    scanner = GetResultValue(builder.get().Finish())
    creader = GetResultValue(scanner.get().ToRecordBatchReader())
    reader = RecordBatchReader()
@@ -0,0 +1,13 @@ (new file)
# Copyright 2022 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,63 @@ (new file)
# Copyright 2022 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import List, Optional, Union

import pyarrow as pa
import pyarrow.dataset

try:
    import torch
    from torch.utils.data import IterableDataset
except ImportError as e:
    raise ImportError("Please install pytorch", e)

from lance import dataset, scanner

__all__ = ["LanceDataset"]


class LanceDataset(IterableDataset):
    """A PyTorch IterableDataset.

    See:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
    """

    def __init__(
        self,
        uri: Union[str, Path],
        columns: Optional[List[str]] = None,
        batch_size: Optional[int] = None,
    ):
        self.uri = uri
        self.columns = columns if columns else []
        self.batch_size = batch_size
        self.scanner: pa.dataset.Scanner = scanner(
            dataset(self.uri), columns=columns, batch_size=batch_size
        )

    def __repr__(self):
        return f"LanceDataset(uri={self.uri})"

    def __iter__(self):
        """Yield dataset"""
        for batch in self.scanner.to_reader():
            # TODO: arrow.to_numpy(writable=True) makes a new copy of data.
            # Investigate how to directly perform zero-copy into Torch Tensor.
            yield [
                torch.from_numpy(arr.to_numpy(zero_copy_only=False, writable=True))
                for arr in batch.columns
            ]
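Because batching is done by the underlying Lance scanner, a LanceDataset would typically be wrapped in a torch DataLoader with batch_size=None so that PyTorch does not re-batch the already-batched tensors. A minimal sketch, assuming a hypothetical dataset path and column names:

from torch.utils.data import DataLoader

from lance.pytorch.data import LanceDataset

ds = LanceDataset("/tmp/example.lance", columns=["id", "value"], batch_size=256)
loader = DataLoader(ds, batch_size=None)  # pass scanner batches through unchanged
for id_tensor, value_tensor in loader:
    # Each element is a 1-D tensor with up to 256 values per column.
    pass

One caveat of the IterableDataset approach as written: it does not shard work across DataLoader workers, so with num_workers > 0 each worker would iterate the full dataset.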
@@ -16,10 +16,10 @@
from pathlib import Path

Review comment: Changes in this whole file are just formatting, right?
Reply: Yes, that's right.

import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
-from lance import write_table, dataset, LanceFileFormat
+from lance import LanceFileFormat, dataset, write_table


def test_simple_round_trips(tmp_path: Path):

@@ -48,7 +48,6 @@ def test_write_categorical_values(tmp_path: Path):
    assert table == actual
-


def test_write_dataset(tmp_path: Path):
    table = pa.Table.from_pandas(
        pd.DataFrame(

@@ -59,12 +58,7 @@ def test_write_dataset(tmp_path: Path):
            }
        )
    )
-    ds.write_dataset(
-        table,
-        tmp_path,
-        partitioning=["split"],
-        format=LanceFileFormat()
-    )
+    ds.write_dataset(table, tmp_path, partitioning=["split"], format=LanceFileFormat())

    part_dirs = [d.name for d in tmp_path.iterdir()]
    assert set(part_dirs) == set(["a", "b"])
@@ -0,0 +1,37 @@ (new file)
# Copyright 2022 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

torch = pytest.importorskip("torch")

from pathlib import Path

import pyarrow as pa

import lance
from lance.pytorch.data import LanceDataset


def test_data_loader(tmp_path: Path):
    torch.Tensor([1, 2, 3])
    ids = pa.array(range(10))
    values = pa.array(range(10, 20))
    tab = pa.Table.from_arrays([ids, values], names=["id", "value"])

    lance.write_table(tab, tmp_path / "lance")

    dataset = LanceDataset(tmp_path / "lance", batch_size=4)
    id_batch, value_batch = next(iter(dataset))
    assert id_batch.shape == torch.Size([4])
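To make the batch layout explicit, a hypothetical extension of this test (not part of the PR, and assuming the scanner emits full batches up to batch_size): with 10 rows and batch_size=4, iterating the dataset would be expected to yield tensors of 4, 4, and 2 elements, matching the original Arrow arrays.

batches = list(iter(LanceDataset(tmp_path / "lance", batch_size=4)))
assert [b[0].shape[0] for b in batches] == [4, 4, 2]  # assumes exact batch splits
assert torch.equal(batches[0][0], torch.arange(0, 4))  # ids 0..3 in the first batch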
Review comment: Install pytorch directly?
Reply: Fixed.