Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Tf.data pipeline #1087

Merged
merged 25 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
workspaces: python
- name: Install linting tools
run: |
pip install black isort ruff maturin
pip install black isort ruff maturin tensorflow
- name: Lint Python
run: |
black --check python
Expand Down Expand Up @@ -157,4 +157,4 @@ jobs:
with:
workspaces: python
- uses: ./.github/workflows/build_windows_wheel
- uses: ./.github/workflows/run_tests
- uses: ./.github/workflows/run_tests
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ duckdb>=0.8
jupyterlab
fastai
xmltodict
tensorflow
1 change: 1 addition & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ tests = [
"duckdb",
"polars[pyarrow,pandas]",
"ml_dtypes",
"tensorflow",
]

[tool.isort]
Expand Down
1 change: 1 addition & 0 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def get_fragments(
) -> Iterator[pa.dataset.Fragment]:
"""Get all fragments from the dataset.


Note: filter is not supported yet.
"""
if filter is not None:
Expand Down
21 changes: 21 additions & 0 deletions python/python/lance/tf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2023. Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from importlib.util import find_spec

# Fail fast with an actionable message if the optional tensorflow
# dependency is missing, rather than a bare ModuleNotFoundError later.
if find_spec("tensorflow") is None:
    raise ImportError(
        "Tensorflow is not installed. Please install tensorflow"
        " to use lance.tf module.",
    )
229 changes: 229 additions & 0 deletions python/python/lance/tf/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright (c) 2023. Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Tensorflow Dataset (`tf.data <https://www.tensorflow.org/guide/data>`_)
implementation for Lance.

.. warning::

Experimental feature. API stability is not guaranteed.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union

import lance
import numpy as np
import pyarrow as pa
import tensorflow as tf
from lance import LanceDataset
from lance.fragment import LanceFragment


def arrow_data_type_to_tf(dt: pa.DataType) -> tf.DType:
    """Map a PyArrow :class:`~pyarrow.DataType` to the Tensorflow ``DType``.

    Raises ``TypeError`` when the Arrow type has no Tensorflow equivalent.
    """
    # Predicate/dtype pairs; the pa.types.is_* checks are mutually
    # exclusive, so scan order does not affect the result.
    _CONVERSIONS = [
        (pa.types.is_boolean, tf.bool),
        (pa.types.is_int8, tf.int8),
        (pa.types.is_int16, tf.int16),
        (pa.types.is_int32, tf.int32),
        (pa.types.is_int64, tf.int64),
        (pa.types.is_uint8, tf.uint8),
        (pa.types.is_uint16, tf.uint16),
        (pa.types.is_uint32, tf.uint32),
        (pa.types.is_uint64, tf.uint64),
        (pa.types.is_float16, tf.float16),
        (pa.types.is_float32, tf.float32),
        (pa.types.is_float64, tf.float64),
    ]
    for predicate, tf_dtype in _CONVERSIONS:
        if predicate(dt):
            return tf_dtype

    # All string/binary variants (including "large" layouts) map to tf.string.
    if (
        pa.types.is_string(dt)
        or pa.types.is_large_string(dt)
        or pa.types.is_binary(dt)
        or pa.types.is_large_binary(dt)
    ):
        return tf.string

    raise TypeError(f"Arrow/Tf conversion: Unsupported arrow data type: {dt}")


def data_type_to_tensor_spec(dt: pa.DataType) -> tf.TensorSpec:
    """Convert PyArrow DataType to Tensorflow TensorSpec.

    The leading ``None`` dimension is the batch dimension.

    Raises
    ------
    TypeError
        If the Arrow data type has no Tensorflow equivalent.
    """
    if (
        pa.types.is_boolean(dt)
        or pa.types.is_integer(dt)
        or pa.types.is_floating(dt)
        or pa.types.is_string(dt)
    ):
        # Scalar column: one value per row.
        return tf.TensorSpec(shape=(None,), dtype=arrow_data_type_to_tf(dt))
    elif pa.types.is_fixed_size_list(dt):
        # Fixed-size list: the inner dimension is known statically.
        return tf.TensorSpec(
            shape=(None, dt.list_size), dtype=arrow_data_type_to_tf(dt.value_type)
        )
    elif pa.types.is_list(dt) or pa.types.is_large_list(dt):
        # Variable-length list: inner dimension unknown.
        # NOTE(review): this returns a dense TensorSpec, but
        # `column_to_tensor` only builds ragged tensors for
        # RaggedTensorSpec — variable-length rows may fail in
        # tf.constant; consider tf.RaggedTensorSpec here. TODO confirm.
        return tf.TensorSpec(
            shape=(
                None,
                None,
            ),
            dtype=arrow_data_type_to_tf(dt.value_type),
        )

    # Fixed: was `TypeError("Unsupported data type: ", dt)`, which passes
    # two args and renders the message as a tuple.
    raise TypeError(f"Unsupported data type: {dt}")


def schema_to_spec(schema: pa.Schema) -> tf.TypeSpec:
    """Build the Tensorflow output signature from a PyArrow Schema.

    Returns a dict mapping each field name to its ``tf.TensorSpec``.
    """
    # Iterating a pyarrow.Schema yields its fields in order.
    return {field.name: data_type_to_tensor_spec(field.type) for field in schema}


def column_to_tensor(column: pa.Array, tensor_spec: tf.TensorSpec) -> tf.Tensor:
    """Convert one Arrow column into a Tensorflow tensor.

    Ragged specs (variable-length rows) get a ragged tensor; everything
    else is converted with ``tf.constant``.
    """
    is_ragged = isinstance(tensor_spec, tf.RaggedTensorSpec)
    make_tensor = tf.ragged.constant if is_ragged else tf.constant
    return make_tensor(column, dtype=tensor_spec.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any docs on how this works? We can just pass arrow arrays into tf.constant and tensorflow understands arrow? Or is there some other protocol here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://www.tensorflow.org/tutorials/customization/basics#numpy_compatibility — it is not yet clear to me whether this is a zero-copy operation.

Let's dig more into the implementation details as a follow-up? I wanted to get the API spec out for now.



def from_lance(
    dataset: Union[str, Path, LanceDataset],
    *,
    columns: Optional[List[str]] = None,
    batch_size: int = 256,
    filter: Optional[str] = None,
    fragments: Optional[Union[Iterable[LanceFragment], tf.data.Dataset]] = None,
    output_signature: Optional[Dict[str, tf.TypeSpec]] = None,
) -> tf.data.Dataset:
    """Create a ``tf.data.Dataset`` from a Lance dataset.

    Parameters
    ----------
    dataset : Union[str, Path, LanceDataset]
        Lance dataset or dataset URI/path.
    columns : Optional[List[str]], optional
        List of columns to include in the output dataset.
        If not set, all columns will be read.
    batch_size : int, optional
        Batch size, by default 256
    filter : Optional[str], optional
        SQL filter expression, by default None.
    fragments : Union[List[LanceFragment], tf.data.Dataset], optional
        If provided, only the fragments are read. It can be used to feed
        for distributed training.
    output_signature : Optional[tf.TypeSpec], optional
        Override output signature of the returned tensors. If not provided,
        the output signature is inferred from the projection Schema.

    Examples
    --------

    .. code-block:: python

        import tensorflow as tf
        import lance.tf.data

        ds = lance.tf.data.from_lance(
            "s3://bucket/path",
            columns=["image", "id"],
            filter="catalog = 'train' AND split = 'train'",
            batch_size=100)

        for batch in ds.repeat(10).shuffle(128).map(io_decode):
            print(batch["image"].shape)

    ``from_lance`` can take an iterator or ``tf.data.Dataset`` of
    Fragments, so that it can be used to feed for distributed training.

    .. code-block:: python

        import tensorflow as tf
        import lance.tf.data

        seed = 200  # seed to shuffle the fragments in distributed machines.
        fragments = lance.tf.data.lance_fragments(
            "s3://bucket/path"
        ).repeat(10).shuffle(4, seed=seed)
        ds = lance.tf.data.from_lance(
            "s3://bucket/path",
            columns=["image", "id"],
            filter="catalog = 'train' AND split = 'train'",
            fragments=fragments,
            batch_size=100)
        for batch in ds.shuffle(128).map(io_decode):
            print(batch["image"].shape)

    """
    if not isinstance(dataset, LanceDataset):
        dataset = lance.dataset(dataset)
    # Normalize fragment ids into a plain Python iterable of ids.
    if isinstance(fragments, tf.data.Dataset):
        fragments = list(fragments.as_numpy_iterator())
    elif isinstance(fragments, np.ndarray):
        fragments = list(fragments)

    if fragments is not None:
        # A Generator of Fragments
        fragments = (LanceFragment(dataset, f) for f in fragments)
    scanner = dataset.scanner(
        filter=filter, columns=columns, batch_size=batch_size, fragments=fragments
    )

    # Infer the output signature from the projected schema unless the
    # caller supplied one explicitly.
    if output_signature is None:
        schema = scanner.projected_schema
        output_signature = schema_to_spec(schema)
    logging.debug("Output signature: %s", output_signature)

    def generator():
        # Stream record batches so only one batch is materialized at a time.
        for batch in scanner.to_batches():
            data = batch.to_pydict()
            yield {
                name: column_to_tensor(column, output_signature[name])
                for name, column in data.items()
            }

    return tf.data.Dataset.from_generator(generator, output_signature=output_signature)


def lance_fragments(dataset: Union[str, Path, LanceDataset]) -> tf.data.Dataset:
    """Create a ``tf.data.Dataset`` of Lance Fragments in the dataset.

    Parameters
    ----------
    dataset : Union[str, Path, LanceDataset]
        A Lance Dataset or dataset URI/path.
    """
    # Open the dataset first when given a URI or path.
    if not isinstance(dataset, LanceDataset):
        dataset = lance.dataset(dataset)
    fragment_ids = [frag.fragment_id for frag in dataset.get_fragments()]
    return tf.data.Dataset.from_tensor_slices(fragment_ids)


# Register `from_lance` to ``tf.data.Dataset`` so callers can write
# ``tf.data.Dataset.from_lance(...)`` after importing this module.
tf.data.Dataset.from_lance = from_lance
Loading