
feat: parse TFRecords as Arrow data #1166

Merged: 14 commits, Aug 30, 2023
4 changes: 2 additions & 2 deletions python/Cargo.toml
@@ -20,10 +20,10 @@
chrono = "0.4.23"
env_logger = "0.10"
futures = "0.3"
half = { version = "2.1", default-features = false, features = ["num-traits"] }
-lance = { path = "../rust" }
+lance = { path = "../rust", features = ["tensorflow"] }
lazy_static = "1"
log = "0.4"
prost = "0.11"
prost = "0.10"
pyo3 = { version = "0.19", features = ["extension-module", "abi3-py38"] }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
uuid = "1.3.0"
10 changes: 10 additions & 0 deletions python/python/lance/lance.pyi
@@ -0,0 +1,10 @@
from typing import List, Optional

import pyarrow as pa

def infer_tfrecord_schema(
    uri: str,
    tensor_features: Optional[List[str]] = None,
    string_features: Optional[List[str]] = None,
) -> pa.Schema: ...
def read_tfrecord(uri: str, schema: pa.Schema) -> pa.RecordBatchReader: ...
16 changes: 16 additions & 0 deletions python/python/lance/tf/tfrecord.py
@@ -0,0 +1,16 @@
# Copyright (c) 2023. Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lance.lance import infer_tfrecord_schema as infer_tfrecord_schema
from lance.lance import read_tfrecord as read_tfrecord
227 changes: 227 additions & 0 deletions python/python/tests/test_tf.py
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ml_dtypes
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lance.arrow import BFloat16Type, bfloat16_array

try:
    import tensorflow as tf  # noqa: F401
@@ -28,6 +31,7 @@
import lance
from lance.fragment import LanceFragment
from lance.tf.data import from_lance, lance_fragments
from lance.tf.tfrecord import infer_tfrecord_schema, read_tfrecord


@pytest.fixture
@@ -148,3 +152,226 @@ def test_var_length_list(tmp_path):
assert batch["a"].numpy()[0] == idx * 8
assert batch["l"].shape == (8, None)
assert isinstance(batch["l"], tf.RaggedTensor)


@pytest.fixture
def sample_tf_example():
    # Create a TFRecord with a string, float, int, and a tensor
    tensor = tf.constant(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32))
    tensor_bf16 = tf.constant(
        np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=ml_dtypes.bfloat16)
    )

    feature = {
        "1_int": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
        "2_int_list": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 3])),
        "3_float": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])),
        "4_float_list": tf.train.Feature(
            float_list=tf.train.FloatList(value=[1.0, 2.0, 3.0])
        ),
        "5_bytes": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[b"Hello, TensorFlow!"])
        ),
        "6_bytes_list": tf.train.Feature(
            bytes_list=tf.train.BytesList(
                value=[b"Hello, TensorFlow!", b"Hello, Lance!"]
            )
        ),
        "7_string": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[b"Hello, TensorFlow!"])
        ),
        "8_tensor": tf.train.Feature(
            bytes_list=tf.train.BytesList(
                value=[tf.io.serialize_tensor(tensor).numpy()]
            )
        ),
        "9_tensor_bf16": tf.train.Feature(
            bytes_list=tf.train.BytesList(
                value=[tf.io.serialize_tensor(tensor_bf16).numpy()]
            )
        ),
    }

    return tf.train.Example(features=tf.train.Features(feature=feature))


def test_tfrecord_parsing(tmp_path, sample_tf_example):
    serialized = sample_tf_example.SerializeToString()

    path = tmp_path / "test.tfrecord"
    with tf.io.TFRecordWriter(str(path)) as writer:
        writer.write(serialized)

    inferred_schema = infer_tfrecord_schema(str(path))

    assert inferred_schema == pa.schema(
        {
            "1_int": pa.int64(),
            "2_int_list": pa.list_(pa.int64()),
            "3_float": pa.float32(),
            "4_float_list": pa.list_(pa.float32()),
            "5_bytes": pa.binary(),
            "6_bytes_list": pa.list_(pa.binary()),
            # tensors and strings assumed binary
            "7_string": pa.binary(),
            "8_tensor": pa.binary(),
            "9_tensor_bf16": pa.binary(),
        }
    )

    inferred_schema = infer_tfrecord_schema(
        str(path),
        tensor_features=["8_tensor", "9_tensor_bf16"],
        string_features=["7_string"],
    )
    assert inferred_schema == pa.schema(
        {
            "1_int": pa.int64(),
            "2_int_list": pa.list_(pa.int64()),
            "3_float": pa.float32(),
            "4_float_list": pa.list_(pa.float32()),
            "5_bytes": pa.binary(),
            "6_bytes_list": pa.list_(pa.binary()),
            "7_string": pa.string(),
            "8_tensor": pa.fixed_shape_tensor(pa.float32(), [2, 3]),
            "9_tensor_bf16": pa.fixed_shape_tensor(BFloat16Type(), [2, 3]),
        }
    )

    reader = read_tfrecord(str(path), inferred_schema)
    assert reader.schema == inferred_schema
    table = reader.read_all()

    assert table.schema == inferred_schema

    tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3])
    inner = pa.array([float(x) for x in range(1, 7)], pa.float32())
    storage = pa.FixedSizeListArray.from_arrays(inner, 6)
    f32_array = pa.ExtensionArray.from_storage(tensor_type, storage)

    tensor_type = pa.fixed_shape_tensor(BFloat16Type(), [2, 3])
    bf16_array = bfloat16_array([float(x) for x in range(1, 7)])
    storage = pa.FixedSizeListArray.from_arrays(bf16_array, 6)
    bf16_array = pa.ExtensionArray.from_storage(tensor_type, storage)

    expected_data = pa.table(
        {
            "1_int": pa.array([1]),
            "2_int_list": pa.array([[1, 2, 3]]),
            "3_float": pa.array([1.0], pa.float32()),
            "4_float_list": pa.array([[1.0, 2.0, 3.0]], pa.list_(pa.float32())),
            "5_bytes": pa.array([b"Hello, TensorFlow!"]),
            "6_bytes_list": pa.array([[b"Hello, TensorFlow!", b"Hello, Lance!"]]),
            "7_string": pa.array(["Hello, TensorFlow!"]),
            "8_tensor": f32_array,
            "9_tensor_bf16": bf16_array,
        }
    )

    assert table == expected_data


def test_tfrecord_roundtrip(tmp_path, sample_tf_example):
    del sample_tf_example.features.feature["9_tensor_bf16"]

    serialized = sample_tf_example.SerializeToString()

    path = tmp_path / "test.tfrecord"
    with tf.io.TFRecordWriter(str(path)) as writer:
        writer.write(serialized)

    schema = infer_tfrecord_schema(
        str(path),
        tensor_features=["8_tensor", "9_tensor_bf16"],
        string_features=["7_string"],
    )

    table = read_tfrecord(str(path), schema).read_all()

    # Can roundtrip to Lance
    dataset_uri = tmp_path / "dataset"
    dataset = lance.write_dataset(table, dataset_uri)
    assert dataset.schema == table.schema
    assert dataset.to_table() == table

    # TODO: validate we can roundtrip with from_lance()
    # tf_ds = from_lance(dataset, batch_size=1)


def test_tfrecord_parsing_nulls(tmp_path):
    # Make sure we don't trip up on missing values
    tensor = tf.constant(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32))

    features = [
        {
            "a": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
            "b": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
            "c": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])),
            "d": tf.train.Feature(
                bytes_list=tf.train.BytesList(
                    value=[tf.io.serialize_tensor(tensor).numpy()]
                )
            ),
        },
        {
            "a": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
        },
        {
            "a": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
            "b": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 3])),
            "c": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])),
        },
    ]

    path = tmp_path / "test.tfrecord"
    with tf.io.TFRecordWriter(str(path)) as writer:
        for feature in features:
            example_proto = tf.train.Example(
                features=tf.train.Features(feature=feature)
            )
            serialized = example_proto.SerializeToString()
            writer.write(serialized)

    inferred_schema = infer_tfrecord_schema(str(path), tensor_features=["d"])
    assert inferred_schema == pa.schema(
        {
            "a": pa.int64(),
            "b": pa.list_(pa.int64()),
            "c": pa.float32(),
            "d": pa.fixed_shape_tensor(pa.float32(), [2, 3]),
        }
    )

    tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3])
    inner = pa.array([float(x) for x in range(1, 7)] + [None] * 12, pa.float32())
    storage = pa.FixedSizeListArray.from_arrays(inner, 6)
    f32_array = pa.ExtensionArray.from_storage(tensor_type, storage)

    data = read_tfrecord(str(path), inferred_schema).read_all()
    expected = pa.table(
        {
            "a": pa.array([1, 1, 1]),
            "b": pa.array([[1], [], [1, 2, 3]]),
            "c": pa.array([1.0, None, 1.0], pa.float32()),
            "d": f32_array,
        }
    )

    assert data == expected

    # can do projection
    read_schema = pa.schema(
        {
            "a": pa.int64(),
            "c": pa.float32(),
        }
    )
    expected = pa.table(
        {
            "a": pa.array([1, 1, 1]),
            "c": pa.array([1.0, None, 1.0], pa.float32()),
        }
    )

    data = read_tfrecord(str(path), read_schema).read_all()
    assert data == expected
20 changes: 20 additions & 0 deletions python/src/executor.rs
@@ -75,6 +75,26 @@ impl BackgroundExecutor {
        rx.recv().unwrap()
    }

    /// Spawn a task in the background
    pub fn spawn_background<T>(&self, py: Option<Python<'_>>, task: T)
    where
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        if let Some(py) = py {
            py.allow_threads(|| {
                self.runtime.spawn(task);
            })
        } else {
            // Python::with_gil is a no-op if the GIL is already held by the thread.
            Python::with_gil(|py| {
                py.allow_threads(|| {
                    self.runtime.spawn(task);
                })
            })
        }
    }

Comment on lines +78 to +97

Contributor: Is this still needed?

Contributor (author): Yeah, that's used in read_tfrecord. The pattern I've found for now for exporting streams as RecordBatchReaders is to shove the stream onto a background task and have it push batches to the iterator via a channel. It's a little awkward, but it seems to work okay.
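
A minimal sketch of that channel-based pattern, assuming arrow-array's RecordBatchIterator, std::sync::mpsc, and a tokio runtime; the function name stream_to_reader and the channel capacity are illustrative, not the actual code in this PR:

use std::sync::mpsc;

use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{ArrowError, SchemaRef};
use futures::{Stream, StreamExt};

/// Drive an async stream on a background task and expose it as a
/// synchronous RecordBatchReader fed through a channel.
fn stream_to_reader<S>(
    runtime: &tokio::runtime::Runtime,
    schema: SchemaRef,
    stream: S,
) -> impl RecordBatchReader
where
    S: Stream<Item = Result<RecordBatch, ArrowError>> + Send + 'static,
{
    // A bounded channel provides backpressure between the async
    // producer and the blocking consumer; the capacity is arbitrary.
    let (tx, rx) = mpsc::sync_channel(4);
    runtime.spawn(async move {
        let mut stream = Box::pin(stream);
        while let Some(batch) = stream.next().await {
            // Stop producing if the reader side has been dropped.
            if tx.send(batch).is_err() {
                break;
            }
        }
    });
    // The receiver's blocking iterator becomes the RecordBatchReader.
    RecordBatchIterator::new(rx.into_iter(), schema)
}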

    /// Block on a future and wait for it to complete.
    ///
    /// This helper method also frees the GIL before blocking.