diff --git a/.github/workflows/_build_test_upload.yml b/.github/workflows/_build_test_upload.yml index 6e0460675..9243e5269 100644 --- a/.github/workflows/_build_test_upload.yml +++ b/.github/workflows/_build_test_upload.yml @@ -91,7 +91,7 @@ jobs: pkginfo $pkg done - name: Install Test Requirements - run: pip3 install expecttest fsspec iopath==0.1.9 numpy pytest rarfile + run: pip3 install expecttest fsspec iopath==0.1.9 numpy pytest rarfile protobuf - name: Run DataPipes Tests with pytest run: pytest --no-header -v test --ignore=test/test_period.py --ignore=test/test_text_examples.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7fe6a3075..74c313325 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -96,7 +96,7 @@ jobs: BUILD_S3: ${{ matrix.with-s3 }} AWSSDK_DIR: ${{ steps.export_path.outputs.awssdk }} - name: Install test requirements - run: pip3 install expecttest fsspec iopath==0.1.9 numpy pytest rarfile + run: pip3 install expecttest fsspec iopath==0.1.9 numpy pytest rarfile protobuf - name: Run DataPipes tests with pytest if: ${{ ! 
contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} run: diff --git a/test/_fakedata/_create_fake_data.py b/test/_fakedata/_create_fake_data.py index daefc82ed..47cb88dd5 100644 --- a/test/_fakedata/_create_fake_data.py +++ b/test/_fakedata/_create_fake_data.py @@ -35,6 +35,61 @@ def create_files(folder, suffix, data, encoding=False): archive.add(folder) +def create_tfrecord_files(path: str): + try: + import tensorflow as tf + except ImportError: + print("TensorFlow not found!") + print("We will not generate tfrecord files.") + return + + os.makedirs(path, exist_ok=True) + with tf.io.TFRecordWriter(os.path.join(path, "example.tfrecord")) as writer: + for i in range(4): + x = tf.range(i * 10, (i + 1) * 10) + record_bytes = tf.train.Example( + features=tf.train.Features( + feature={ + "x_float": tf.train.Feature(float_list=tf.train.FloatList(value=x)), + "x_int": tf.train.Feature(int64_list=tf.train.Int64List(value=tf.cast(x * 10, "int64"))), + "x_byte": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"test str"])), + } + ) + ).SerializeToString() + writer.write(record_bytes) + + with tf.io.TFRecordWriter(os.path.join(path, "sequence_example.tfrecord")) as writer: + for i in range(4): + x = tf.range(i * 10, (i + 1) * 10) + rep = 2 * i + 3 + + record_bytes = tf.train.SequenceExample( + context=tf.train.Features( + feature={ + "x_float": tf.train.Feature(float_list=tf.train.FloatList(value=x)), + "x_int": tf.train.Feature(int64_list=tf.train.Int64List(value=tf.cast(x * 10, "int64"))), + "x_byte": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"test str"])), + } + ), + feature_lists=tf.train.FeatureLists( + feature_list={ + "x_float_seq": tf.train.FeatureList( + feature=[tf.train.Feature(float_list=tf.train.FloatList(value=x))] * rep + ), + "x_int_seq": tf.train.FeatureList( + feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=tf.cast(x * 10, "int64")))] + * rep + ), + "x_byte_seq": tf.train.FeatureList( + 
feature=[tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"test str"]))] * rep + ), + } + ), + ).SerializeToString() + writer.write(record_bytes) + + if __name__ == "__main__": for args in FILES: create_files(*args) + create_tfrecord_files("tfrecord") diff --git a/test/_fakedata/tfrecord/example.tfrecord b/test/_fakedata/tfrecord/example.tfrecord new file mode 100644 index 000000000..a1b7156dc Binary files /dev/null and b/test/_fakedata/tfrecord/example.tfrecord differ diff --git a/test/_fakedata/tfrecord/sequence_example.tfrecord b/test/_fakedata/tfrecord/sequence_example.tfrecord new file mode 100644 index 000000000..a68731a97 Binary files /dev/null and b/test/_fakedata/tfrecord/sequence_example.tfrecord differ diff --git a/test/test_serialization.py b/test/test_serialization.py index 8e4d9498f..f746ea309 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -256,6 +256,7 @@ def test_serializable(self): {"mode": "wb", "filepath_fn": partial(_filepath_fn, dir=self.temp_dir.name)}, ), (iterdp.TarArchiveLoader, None, (), {}), + (iterdp.TFRecordLoader, None, (), {}), (iterdp.UnZipper, IterableWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}), (iterdp.XzFileLoader, None, (), {}), (iterdp.ZipArchiveLoader, None, (), {}), @@ -281,6 +282,7 @@ def test_serializable(self): iterdp.SampleMultiplexer, iterdp.RarArchiveLoader, iterdp.TarArchiveLoader, + iterdp.TFRecordLoader, iterdp.XzFileLoader, iterdp.ZipArchiveLoader, } diff --git a/test/test_tfrecord.py b/test/test_tfrecord.py new file mode 100644 index 000000000..186b2af41 --- /dev/null +++ b/test/test_tfrecord.py @@ -0,0 +1,274 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import os +import unittest +import warnings +from functools import partial + +import expecttest +import numpy as np + +import torch + +from _utils._common_utils_for_test import reset_after_n_next_calls +from torchdata.datapipes.iter import ( + FileLister, + FileOpener, + FSSpecFileLister, + FSSpecFileOpener, + FSSpecSaver, + IterableWrapper, + TFRecordLoader, +) + + +class TestDataPipeTFRecord(expecttest.TestCase): + def setUp(self): + self.temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "_fakedata", "tfrecord") + + def assertArrayEqual(self, arr1, arr2): + np.testing.assert_array_equal(arr1, arr2) + + def _ground_truth_data(self): + for i in range(4): + x = torch.range(i * 10, (i + 1) * 10 - 1) + yield { + "x_float": x, + "x_int": (x * 10).long(), + "x_byte": [b"test str"], + } + + def _ground_truth_seq_data(self): + for i in range(4): + x = torch.range(i * 10, (i + 1) * 10 - 1) + rep = 2 * i + 3 + yield {"x_float": x, "x_int": (x * 10).long(), "x_byte": [b"test str"]}, { + "x_float_seq": [x] * rep, + "x_int_seq": [(x * 10).long()] * rep, + "x_byte_seq": [[b"test str"]] * rep, + } + + @torch.no_grad() + def test_tfrecord_loader_example_iterdatapipe(self): + filename = f"{self.temp_dir}/example.tfrecord" + datapipe1 = IterableWrapper([filename]) + datapipe2 = FileOpener(datapipe1, mode="b") + + # Functional Test: test if the returned data is correct + tfrecord_parser = datapipe2.load_from_tfrecord() + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + expected_res = final_expected_res = list(self._ground_truth_data()) + for true_data, loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(len(loaded_data["x_byte"]), 1) + self.assertEqual(true_data["x_byte"][0], loaded_data["x_byte"][0]) + + # Functional Test: test if the shape of the returned 
data is correct when using spec + tfrecord_parser = datapipe2.load_from_tfrecord( + { + "x_float": ((5, 2), torch.float64), + "x_int": ((5, 2), torch.int32), + "x_byte": (tuple(), None), + } + ) + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + expected_res = [ + { + "x_float": x["x_float"].reshape(5, 2), + "x_int": x["x_int"].reshape(5, 2), + "x_byte": x["x_byte"][0], + } + for x in self._ground_truth_data() + ] + for true_data, loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + self.assertArrayEqual(true_data["x_float"].numpy(), loaded_data["x_float"].float().numpy()) + self.assertArrayEqual(true_data["x_int"].numpy(), loaded_data["x_int"].long().numpy()) + self.assertEqual(loaded_data["x_float"].dtype, torch.float64) + self.assertEqual(loaded_data["x_int"].dtype, torch.int32) + self.assertEqual(true_data["x_byte"], loaded_data["x_byte"]) + + # Functional Test: ignore features missing from spec + tfrecord_parser = datapipe2.load_from_tfrecord( + { + "x_float": ((10,), torch.float32), + } + ) + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + expected_res = [ + { + "x_float": x["x_float"], + } + for x in self._ground_truth_data() + ] + for true_data, loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + self.assertArrayEqual(true_data["x_float"].numpy(), loaded_data["x_float"].float().numpy()) + + # Functional Test: raises error if missing spec feature + with self.assertRaises(RuntimeError): + tfrecord_parser = datapipe2.load_from_tfrecord( + { + "x_float_unknown": ((5, 2), torch.float64), + "x_int": ((5, 2), torch.int32), + "x_byte": (tuple(), None), + } + ) + result = list(tfrecord_parser) + + # Reset Test: + tfrecord_parser = TFRecordLoader(datapipe2) + expected_res = final_expected_res + n_elements_before_reset = 2 + res_before_reset, res_after_reset = reset_after_n_next_calls(tfrecord_parser, 
n_elements_before_reset) + self.assertEqual(len(expected_res[:n_elements_before_reset]), len(res_before_reset)) + for true_data, loaded_data in zip(expected_res[:n_elements_before_reset], res_before_reset): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(true_data["x_byte"][0], loaded_data["x_byte"][0]) + self.assertEqual(len(expected_res), len(res_after_reset)) + for true_data, loaded_data in zip(expected_res, res_after_reset): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(true_data["x_byte"][0], loaded_data["x_byte"][0]) + + # __len__ Test: length isn't implemented since it cannot be known ahead of time + with self.assertRaisesRegex(TypeError, "doesn't have valid length"): + len(tfrecord_parser) + + @torch.no_grad() + def test_tfrecord_loader_sequence_example_iterdatapipe(self): + filename = f"{self.temp_dir}/sequence_example.tfrecord" + datapipe1 = IterableWrapper([filename]) + datapipe2 = FileOpener(datapipe1, mode="b") + + # Functional Test: test if the returned data is correct + tfrecord_parser = datapipe2.load_from_tfrecord() + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + expected_res = final_expected_res = list(self._ground_truth_seq_data()) + for (true_data_ctx, true_data_seq), loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data_ctx[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(len(true_data_seq[key + "_seq"]), len(loaded_data[key + "_seq"])) + self.assertIsInstance(loaded_data[key + "_seq"], list) + for a1, a2 in zip(true_data_seq[key + "_seq"], loaded_data[key + 
"_seq"]): + self.assertArrayEqual(a1, a2) + self.assertEqual(true_data_ctx["x_byte"], loaded_data["x_byte"]) + self.assertListEqual(true_data_seq["x_byte_seq"], loaded_data["x_byte_seq"]) + + # Functional Test: test if the shape of the returned data is correct when using spec + tfrecord_parser = datapipe2.load_from_tfrecord( + { + "x_float": ((5, 2), torch.float64), + "x_int": ((5, 2), torch.int32), + "x_byte": (tuple(), None), + "x_float_seq": ((-1, 5, 2), torch.float64), + "x_int_seq": ((-1, 5, 2), torch.int32), + "x_byte_seq": ((-1,), None), + } + ) + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + + expected_res = [ + ( + { + "x_float": x["x_float"].reshape(5, 2), + "x_int": x["x_int"].reshape(5, 2), + "x_byte": x["x_byte"][0], + }, + { + "x_float_seq": [y.reshape(5, 2).numpy() for y in z["x_float_seq"]], + "x_int_seq": [y.reshape(5, 2).numpy() for y in z["x_int_seq"]], + "x_byte_seq": [y[0] for y in z["x_byte_seq"]], + }, + ) + for x, z in self._ground_truth_seq_data() + ] + for (true_data_ctx, true_data_seq), loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + l_loaded_data = loaded_data[key] + if key == "x_float": + l_loaded_data = l_loaded_data.float() + else: + l_loaded_data = l_loaded_data.int() + self.assertArrayEqual(true_data_ctx[key].numpy(), l_loaded_data.numpy()) + self.assertArrayEqual(true_data_seq[key + "_seq"], loaded_data[key + "_seq"]) + self.assertEqual(true_data_ctx["x_byte"], loaded_data["x_byte"]) + self.assertListEqual(true_data_seq["x_byte_seq"], loaded_data["x_byte_seq"]) + + # Functional Test: ignore features missing from spec + tfrecord_parser = datapipe2.load_from_tfrecord( + { + "x_float": ((10,), torch.float32), + } + ) + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + expected_res = [ + { + "x_float": x["x_float"], + } + for x, z in self._ground_truth_seq_data() + 
] + for true_data, loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + self.assertArrayEqual(true_data["x_float"].numpy(), loaded_data["x_float"].float().numpy()) + + # Functional Test: raises error if missing spec feature + with self.assertRaises(RuntimeError): + tfrecord_parser = datapipe2.load_from_tfrecord( + {"x_float_unknown": ((5, 2), torch.float64), "x_int": ((5, 2), torch.int32), "x_byte": None} + ) + result = list(tfrecord_parser) + + # Reset Test: + tfrecord_parser = TFRecordLoader(datapipe2) + expected_res = final_expected_res + n_elements_before_reset = 2 + res_before_reset, res_after_reset = reset_after_n_next_calls(tfrecord_parser, n_elements_before_reset) + self.assertEqual(len(expected_res[:n_elements_before_reset]), len(res_before_reset)) + for (true_data_ctx, true_data_seq), loaded_data in zip( + expected_res[:n_elements_before_reset], res_before_reset + ): + self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data_ctx[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(len(true_data_seq[key + "_seq"]), len(loaded_data[key + "_seq"])) + self.assertIsInstance(loaded_data[key + "_seq"], list) + for a1, a2 in zip(true_data_seq[key + "_seq"], loaded_data[key + "_seq"]): + self.assertArrayEqual(a1, a2) + self.assertEqual(true_data_ctx["x_byte"], loaded_data["x_byte"]) + self.assertListEqual(true_data_seq["x_byte_seq"], loaded_data["x_byte_seq"]) + self.assertEqual(len(expected_res), len(res_after_reset)) + for (true_data_ctx, true_data_seq), loaded_data in zip(expected_res, res_after_reset): + self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data_ctx[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(len(true_data_seq[key + "_seq"]), len(loaded_data[key 
+ "_seq"])) + self.assertIsInstance(loaded_data[key + "_seq"], list) + for a1, a2 in zip(true_data_seq[key + "_seq"], loaded_data[key + "_seq"]): + self.assertArrayEqual(a1, a2) + self.assertEqual(true_data_ctx["x_byte"], loaded_data["x_byte"]) + self.assertListEqual(true_data_seq["x_byte_seq"], loaded_data["x_byte_seq"]) + + # __len__ Test: length isn't implemented since it cannot be known ahead of time + with self.assertRaisesRegex(TypeError, "doesn't have valid length"): + len(tfrecord_parser) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index 35befa0df..d13cef0d3 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -102,6 +102,7 @@ TarArchiveLoaderIterDataPipe as TarArchiveLoader, TarArchiveReaderIterDataPipe as TarArchiveReader, ) +from torchdata.datapipes.iter.util.tfrecordloader import TFRecordLoaderIterDataPipe as TFRecordLoader from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper from torchdata.datapipes.iter.util.xzfileloader import ( XzFileLoaderIterDataPipe as XzFileLoader, @@ -171,6 +172,7 @@ "ShardingFilter", "Shuffler", "StreamReader", + "TFRecordLoader", "TarArchiveLoader", "TarArchiveReader", "UnBatcher", diff --git a/torchdata/datapipes/iter/util/protobuf_template/__init__.py b/torchdata/datapipes/iter/util/protobuf_template/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchdata/datapipes/iter/util/protobuf_template/_tfrecord_example_pb2.py b/torchdata/datapipes/iter/util/protobuf_template/_tfrecord_example_pb2.py new file mode 100644 index 000000000..972c61ff4 --- /dev/null +++ b/torchdata/datapipes/iter/util/protobuf_template/_tfrecord_example_pb2.py @@ -0,0 +1,699 @@ +# type: ignore +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: example.proto + +import sys + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) +from google.protobuf import ( + descriptor as _descriptor, + descriptor_pb2, + message as _message, + reflection as _reflection, + symbol_database as _symbol_database, +) + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor.FileDescriptor( + name="example.proto", + package="tfrecord", + syntax="proto3", + serialized_pb=_b( + '\n\rexample.proto\x12\x08tfrecord"\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c"\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01"\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01"\x92\x01\n\x07\x46\x65\x61ture\x12)\n\nbytes_list\x18\x01 \x01(\x0b\x32\x13.tfrecord.BytesListH\x00\x12)\n\nfloat_list\x18\x02 \x01(\x0b\x32\x13.tfrecord.FloatListH\x00\x12)\n\nint64_list\x18\x03 \x01(\x0b\x32\x13.tfrecord.Int64ListH\x00\x42\x06\n\x04kind"\x7f\n\x08\x46\x65\x61tures\x12\x30\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x1f.tfrecord.Features.FeatureEntry\x1a\x41\n\x0c\x46\x65\x61tureEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.tfrecord.Feature:\x02\x38\x01"1\n\x0b\x46\x65\x61tureList\x12"\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x11.tfrecord.Feature"\x98\x01\n\x0c\x46\x65\x61tureLists\x12=\n\x0c\x66\x65\x61ture_list\x18\x01 \x03(\x0b\x32\'.tfrecord.FeatureLists.FeatureListEntry\x1aI\n\x10\x46\x65\x61tureListEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tfrecord.FeatureList:\x02\x38\x01"/\n\x07\x45xample\x12$\n\x08\x66\x65\x61tures\x18\x01 \x01(\x0b\x32\x12.tfrecord.Features"e\n\x0fSequenceExample\x12#\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x12.tfrecord.Features\x12-\n\rfeature_lists\x18\x02 \x01(\x0b\x32\x16.tfrecord.FeatureListsB\x03\xf8\x01\x01\x62\x06proto3' + ), +) +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + + +_BYTESLIST = _descriptor.Descriptor( + 
name="BytesList", + full_name="tfrecord.BytesList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="value", + full_name="tfrecord.BytesList.value", + index=0, + number=1, + type=12, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=27, + serialized_end=53, +) + + +_FLOATLIST = _descriptor.Descriptor( + name="FloatList", + full_name="tfrecord.FloatList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="value", + full_name="tfrecord.FloatList.value", + index=0, + number=1, + type=2, + cpp_type=6, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=55, + serialized_end=85, +) + + +_INT64LIST = _descriptor.Descriptor( + name="Int64List", + full_name="tfrecord.Int64List", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="value", + full_name="tfrecord.Int64List.value", + index=0, + number=1, + type=3, + cpp_type=2, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + ), + ], + extensions=[], + nested_types=[], 
+ enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=87, + serialized_end=117, +) + + +_FEATURE = _descriptor.Descriptor( + name="Feature", + full_name="tfrecord.Feature", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="bytes_list", + full_name="tfrecord.Feature.bytes_list", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="float_list", + full_name="tfrecord.Feature.float_list", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="int64_list", + full_name="tfrecord.Feature.int64_list", + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="kind", full_name="tfrecord.Feature.kind", index=0, containing_type=None, fields=[] + ), + ], + serialized_start=120, + serialized_end=266, +) + + +_FEATURES_FEATUREENTRY = _descriptor.Descriptor( + name="FeatureEntry", + full_name="tfrecord.Features.FeatureEntry", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="tfrecord.Features.FeatureEntry.key", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, 
+ has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="value", + full_name="tfrecord.Features.FeatureEntry.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=330, + serialized_end=395, +) + +_FEATURES = _descriptor.Descriptor( + name="Features", + full_name="tfrecord.Features", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="feature", + full_name="tfrecord.Features.feature", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[ + _FEATURES_FEATUREENTRY, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=268, + serialized_end=395, +) + + +_FEATURELIST = _descriptor.Descriptor( + name="FeatureList", + full_name="tfrecord.FeatureList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="feature", + full_name="tfrecord.FeatureList.feature", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + 
extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=397, + serialized_end=446, +) + + +_FEATURELISTS_FEATURELISTENTRY = _descriptor.Descriptor( + name="FeatureListEntry", + full_name="tfrecord.FeatureLists.FeatureListEntry", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="tfrecord.FeatureLists.FeatureListEntry.key", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="value", + full_name="tfrecord.FeatureLists.FeatureListEntry.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=528, + serialized_end=601, +) + +_FEATURELISTS = _descriptor.Descriptor( + name="FeatureLists", + full_name="tfrecord.FeatureLists", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="feature_list", + full_name="tfrecord.FeatureLists.feature_list", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[ + 
_FEATURELISTS_FEATURELISTENTRY, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=449, + serialized_end=601, +) + + +_EXAMPLE = _descriptor.Descriptor( + name="Example", + full_name="tfrecord.Example", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="features", + full_name="tfrecord.Example.features", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=603, + serialized_end=650, +) + + +_SEQUENCEEXAMPLE = _descriptor.Descriptor( + name="SequenceExample", + full_name="tfrecord.SequenceExample", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="context", + full_name="tfrecord.SequenceExample.context", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="feature_lists", + full_name="tfrecord.SequenceExample.feature_lists", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=652, + serialized_end=753, +) + 
+_FEATURE.fields_by_name["bytes_list"].message_type = _BYTESLIST +_FEATURE.fields_by_name["float_list"].message_type = _FLOATLIST +_FEATURE.fields_by_name["int64_list"].message_type = _INT64LIST +_FEATURE.oneofs_by_name["kind"].fields.append(_FEATURE.fields_by_name["bytes_list"]) +_FEATURE.fields_by_name["bytes_list"].containing_oneof = _FEATURE.oneofs_by_name["kind"] +_FEATURE.oneofs_by_name["kind"].fields.append(_FEATURE.fields_by_name["float_list"]) +_FEATURE.fields_by_name["float_list"].containing_oneof = _FEATURE.oneofs_by_name["kind"] +_FEATURE.oneofs_by_name["kind"].fields.append(_FEATURE.fields_by_name["int64_list"]) +_FEATURE.fields_by_name["int64_list"].containing_oneof = _FEATURE.oneofs_by_name["kind"] +_FEATURES_FEATUREENTRY.fields_by_name["value"].message_type = _FEATURE +_FEATURES_FEATUREENTRY.containing_type = _FEATURES +_FEATURES.fields_by_name["feature"].message_type = _FEATURES_FEATUREENTRY +_FEATURELIST.fields_by_name["feature"].message_type = _FEATURE +_FEATURELISTS_FEATURELISTENTRY.fields_by_name["value"].message_type = _FEATURELIST +_FEATURELISTS_FEATURELISTENTRY.containing_type = _FEATURELISTS +_FEATURELISTS.fields_by_name["feature_list"].message_type = _FEATURELISTS_FEATURELISTENTRY +_EXAMPLE.fields_by_name["features"].message_type = _FEATURES +_SEQUENCEEXAMPLE.fields_by_name["context"].message_type = _FEATURES +_SEQUENCEEXAMPLE.fields_by_name["feature_lists"].message_type = _FEATURELISTS +DESCRIPTOR.message_types_by_name["BytesList"] = _BYTESLIST +DESCRIPTOR.message_types_by_name["FloatList"] = _FLOATLIST +DESCRIPTOR.message_types_by_name["Int64List"] = _INT64LIST +DESCRIPTOR.message_types_by_name["Feature"] = _FEATURE +DESCRIPTOR.message_types_by_name["Features"] = _FEATURES +DESCRIPTOR.message_types_by_name["FeatureList"] = _FEATURELIST +DESCRIPTOR.message_types_by_name["FeatureLists"] = _FEATURELISTS +DESCRIPTOR.message_types_by_name["Example"] = _EXAMPLE +DESCRIPTOR.message_types_by_name["SequenceExample"] = _SEQUENCEEXAMPLE + 
+BytesList = _reflection.GeneratedProtocolMessageType( + "BytesList", + (_message.Message,), + dict( + DESCRIPTOR=_BYTESLIST, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.BytesList) + ), +) +_sym_db.RegisterMessage(BytesList) + +FloatList = _reflection.GeneratedProtocolMessageType( + "FloatList", + (_message.Message,), + dict( + DESCRIPTOR=_FLOATLIST, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.FloatList) + ), +) +_sym_db.RegisterMessage(FloatList) + +Int64List = _reflection.GeneratedProtocolMessageType( + "Int64List", + (_message.Message,), + dict( + DESCRIPTOR=_INT64LIST, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.Int64List) + ), +) +_sym_db.RegisterMessage(Int64List) + +Feature = _reflection.GeneratedProtocolMessageType( + "Feature", + (_message.Message,), + dict( + DESCRIPTOR=_FEATURE, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.Feature) + ), +) +_sym_db.RegisterMessage(Feature) + +Features = _reflection.GeneratedProtocolMessageType( + "Features", + (_message.Message,), + dict( + FeatureEntry=_reflection.GeneratedProtocolMessageType( + "FeatureEntry", + (_message.Message,), + dict( + DESCRIPTOR=_FEATURES_FEATUREENTRY, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.Features.FeatureEntry) + ), + ), + DESCRIPTOR=_FEATURES, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.Features) + ), +) +_sym_db.RegisterMessage(Features) +_sym_db.RegisterMessage(Features.FeatureEntry) + +FeatureList = _reflection.GeneratedProtocolMessageType( + "FeatureList", + (_message.Message,), + dict( + DESCRIPTOR=_FEATURELIST, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.FeatureList) + ), +) +_sym_db.RegisterMessage(FeatureList) + +FeatureLists = _reflection.GeneratedProtocolMessageType( + "FeatureLists", + (_message.Message,), + dict( + 
FeatureListEntry=_reflection.GeneratedProtocolMessageType( + "FeatureListEntry", + (_message.Message,), + dict( + DESCRIPTOR=_FEATURELISTS_FEATURELISTENTRY, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.FeatureLists.FeatureListEntry) + ), + ), + DESCRIPTOR=_FEATURELISTS, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.FeatureLists) + ), +) +_sym_db.RegisterMessage(FeatureLists) +_sym_db.RegisterMessage(FeatureLists.FeatureListEntry) + +Example = _reflection.GeneratedProtocolMessageType( + "Example", + (_message.Message,), + dict( + DESCRIPTOR=_EXAMPLE, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.Example) + ), +) +_sym_db.RegisterMessage(Example) + +SequenceExample = _reflection.GeneratedProtocolMessageType( + "SequenceExample", + (_message.Message,), + dict( + DESCRIPTOR=_SEQUENCEEXAMPLE, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.SequenceExample) + ), +) +_sym_db.RegisterMessage(SequenceExample) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) +_FLOATLIST.fields_by_name["value"].has_options = True +_FLOATLIST.fields_by_name["value"]._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")) +_INT64LIST.fields_by_name["value"].has_options = True +_INT64LIST.fields_by_name["value"]._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")) +_FEATURES_FEATUREENTRY.has_options = True +_FEATURES_FEATUREENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")) +_FEATURELISTS_FEATURELISTENTRY.has_options = True +_FEATURELISTS_FEATURELISTENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")) +# @@protoc_insertion_point(module_scope) diff --git a/torchdata/datapipes/iter/util/protobuf_template/example.proto 
b/torchdata/datapipes/iter/util/protobuf_template/example.proto new file mode 100644 index 000000000..9f762fb51 --- /dev/null +++ b/torchdata/datapipes/iter/util/protobuf_template/example.proto @@ -0,0 +1,301 @@ +// Protocol messages for describing input data Examples for machine learning +// model training or inference. +syntax = "proto3"; + +package tensorflow; + +import "tensorflow/core/example/feature.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "ExampleProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.example"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/example/example_protos_go_proto"; + +// An Example is a mostly-normalized data format for storing data for +// training and inference. It contains a key-value store (features); where +// each key (string) maps to a Feature message (which is oneof packed BytesList, +// FloatList, or Int64List). This flexible and compact format allows the +// storage of large amounts of typed data, but requires that the data shape +// and use be determined by the configuration files and parsers that are used to +// read and write this format. That is, the Example is mostly *not* a +// self-describing format. In TensorFlow, Examples are read in row-major +// format, so any configuration that describes data with rank-2 or above +// should keep this in mind. If you flatten a matrix into a FloatList it should +// be stored as [ row 0 ... row 1 ... 
row M-1 ] +// +// An Example for a movie recommendation application: +// features { +// feature { +// key: "age" +// value { float_list { +// value: 29.0 +// }} +// } +// feature { +// key: "movie" +// value { bytes_list { +// value: "The Shawshank Redemption" +// value: "Fight Club" +// }} +// } +// feature { +// key: "movie_ratings" +// value { float_list { +// value: 9.0 +// value: 9.7 +// }} +// } +// feature { +// key: "suggestion" +// value { bytes_list { +// value: "Inception" +// }} +// } +// # Note that this feature exists to be used as a label in training. +// # E.g., if training a logistic regression model to predict purchase +// # probability in our learning tool we would set the label feature to +// # "suggestion_purchased". +// feature { +// key: "suggestion_purchased" +// value { float_list { +// value: 1.0 +// }} +// } +// # Similar to "suggestion_purchased" above this feature exists to be used +// # as a label in training. +// # E.g., if training a linear regression model to predict purchase +// # price in our learning tool we would set the label feature to +// # "purchase_price". +// feature { +// key: "purchase_price" +// value { float_list { +// value: 9.99 +// }} +// } +// } +// +// A conformant Example data set obeys the following conventions: +// - If a Feature K exists in one example with data type T, it must be of +// type T in all other examples when present. It may be omitted. +// - The number of instances of Feature K list data may vary across examples, +// depending on the requirements of the model. +// - If a Feature K doesn't exist in an example, a K-specific default will be +// used, if configured. +// - If a Feature K exists in an example but contains no items, the intent +// is considered to be an empty tensor and no default will be used. + +message Example { + Features features = 1; +} + +// A SequenceExample is an Example representing one or more sequences, and +// some context. 
The context contains features which apply to the entire +// example. The feature_lists contain a key, value map where each key is +// associated with a repeated set of Features (a FeatureList). +// A FeatureList thus represents the values of a feature identified by its key +// over time / frames. +// +// Below is a SequenceExample for a movie recommendation application recording a +// sequence of ratings by a user. The time-independent features ("locale", +// "age", "favorites") describing the user are part of the context. The sequence +// of movies the user rated are part of the feature_lists. For each movie in the +// sequence we have information on its name and actors and the user's rating. +// This information is recorded in three separate feature_list(s). +// In the example below there are only two movies. All three feature_list(s), +// namely "movie_ratings", "movie_names", and "actors" have a feature value for +// both movies. Note, that "actors" is itself a bytes_list with multiple +// strings per movie. 
+// +// context: { +// feature: { +// key : "locale" +// value: { +// bytes_list: { +// value: [ "pt_BR" ] +// } +// } +// } +// feature: { +// key : "age" +// value: { +// float_list: { +// value: [ 19.0 ] +// } +// } +// } +// feature: { +// key : "favorites" +// value: { +// bytes_list: { +// value: [ "Majesty Rose", "Savannah Outen", "One Direction" ] +// } +// } +// } +// } +// feature_lists: { +// feature_list: { +// key : "movie_ratings" +// value: { +// feature: { +// float_list: { +// value: [ 4.5 ] +// } +// } +// feature: { +// float_list: { +// value: [ 5.0 ] +// } +// } +// } +// } +// feature_list: { +// key : "movie_names" +// value: { +// feature: { +// bytes_list: { +// value: [ "The Shawshank Redemption" ] +// } +// } +// feature: { +// bytes_list: { +// value: [ "Fight Club" ] +// } +// } +// } +// } +// feature_list: { +// key : "actors" +// value: { +// feature: { +// bytes_list: { +// value: [ "Tim Robbins", "Morgan Freeman" ] +// } +// } +// feature: { +// bytes_list: { +// value: [ "Brad Pitt", "Edward Norton", "Helena Bonham Carter" ] +// } +// } +// } +// } +// } +// +// A conformant SequenceExample data set obeys the following conventions: +// +// Context: +// - All conformant context features K must obey the same conventions as +// a conformant Example's features (see above). +// Feature lists: +// - A FeatureList L may be missing in an example; it is up to the +// parser configuration to determine if this is allowed or considered +// an empty list (zero length). +// - If a FeatureList L exists, it may be empty (zero length). +// - If a FeatureList L is non-empty, all features within the FeatureList +// must have the same data type T. Even across SequenceExamples, the type T +// of the FeatureList identified by the same key must be the same. An entry +// without any values may serve as an empty feature. 
+// - If a FeatureList L is non-empty, it is up to the parser configuration +// to determine if all features within the FeatureList must +// have the same size. The same holds for this FeatureList across multiple +// examples. +// - For sequence modeling, e.g.: +// http://colah.github.io/posts/2015-08-Understanding-LSTMs/ +// https://github.com/tensorflow/nmt +// the feature lists represent a sequence of frames. +// In this scenario, all FeatureLists in a SequenceExample have the same +// number of Feature messages, so that the ith element in each FeatureList +// is part of the ith frame (or time step). +// Examples of conformant and non-conformant examples' FeatureLists: +// +// Conformant FeatureLists: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// +// Non-conformant FeatureLists (mismatched types): +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { int64_list: { value: [ 5 ] } } } +// } } +// +// Conditionally conformant FeatureLists, the parser configuration determines +// if the feature sizes must match: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0, 6.0 ] } } } +// } } +// +// Conformant pair of SequenceExample +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// and: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } +// feature: { float_list: { value: [ 2.0 ] } } } +// } } +// +// Conformant pair of SequenceExample +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { 
feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// and: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { } +// } } +// +// Conditionally conformant pair of SequenceExample, the parser configuration +// determines if the second feature_lists is consistent (zero-length) or +// invalid (missing "movie_ratings"): +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// and: +// feature_lists: { } +// +// Non-conformant pair of SequenceExample (mismatched types) +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// and: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { int64_list: { value: [ 4 ] } } +// feature: { int64_list: { value: [ 5 ] } } +// feature: { int64_list: { value: [ 2 ] } } } +// } } +// +// Conditionally conformant pair of SequenceExample; the parser configuration +// determines if the feature sizes must match: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.5 ] } } +// feature: { float_list: { value: [ 5.0 ] } } } +// } } +// and: +// feature_lists: { feature_list: { +// key: "movie_ratings" +// value: { feature: { float_list: { value: [ 4.0 ] } } +// feature: { float_list: { value: [ 5.0, 3.0 ] } } +// } } + +message SequenceExample { + Features context = 1; + FeatureLists feature_lists = 2; +} diff --git a/torchdata/datapipes/iter/util/protobuf_template/feature.proto b/torchdata/datapipes/iter/util/protobuf_template/feature.proto new file mode 100644 index 000000000..7f9fad982 --- /dev/null +++ b/torchdata/datapipes/iter/util/protobuf_template/feature.proto @@ -0,0 +1,110 @@ +// Protocol messages for describing features for 
machine learning model +// training or inference. +// +// There are three base Feature types: +// - bytes +// - float +// - int64 +// +// A Feature contains Lists which may hold zero or more values. These +// lists are the base values BytesList, FloatList, Int64List. +// +// Features are organized into categories by name. The Features message +// contains the mapping from name to Feature. +// +// Example Features for a movie recommendation application: +// feature { +// key: "age" +// value { float_list { +// value: 29.0 +// }} +// } +// feature { +// key: "movie" +// value { bytes_list { +// value: "The Shawshank Redemption" +// value: "Fight Club" +// }} +// } +// feature { +// key: "movie_ratings" +// value { float_list { +// value: 9.0 +// value: 9.7 +// }} +// } +// feature { +// key: "suggestion" +// value { bytes_list { +// value: "Inception" +// }} +// } +// feature { +// key: "suggestion_purchased" +// value { int64_list { +// value: 1 +// }} +// } +// feature { +// key: "purchase_price" +// value { float_list { +// value: 9.99 +// }} +// } +// + +syntax = "proto3"; + +package tensorflow; + +option cc_enable_arenas = true; +option java_outer_classname = "FeatureProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.example"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/example/example_protos_go_proto"; + +// LINT.IfChange +// Containers to hold repeated fundamental values. +message BytesList { + repeated bytes value = 1; +} +message FloatList { + repeated float value = 1 [packed = true]; +} +message Int64List { + repeated int64 value = 1 [packed = true]; +} + +// Containers for non-sequential data. +message Feature { + // Each feature can be exactly one kind. + oneof kind { + BytesList bytes_list = 1; + FloatList float_list = 2; + Int64List int64_list = 3; + } +} + +message Features { + // Map from feature name to feature. + map feature = 1; +} + +// Containers for sequential data. 
+// +// A FeatureList contains lists of Features. These may hold zero or more +// Feature values. +// +// FeatureLists are organized into categories by name. The FeatureLists message +// contains the mapping from name to FeatureList. +// +message FeatureList { + repeated Feature feature = 1; +} + +message FeatureLists { + // Map from feature name to feature list. + map feature_list = 1; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/python/training/training.py) diff --git a/torchdata/datapipes/iter/util/tfrecordloader.py b/torchdata/datapipes/iter/util/tfrecordloader.py new file mode 100644 index 000000000..3fe3a1627 --- /dev/null +++ b/torchdata/datapipes/iter/util/tfrecordloader.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import struct +import warnings +from functools import partial +from io import BufferedIOBase +from typing import Any, cast, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union + +import torch + +from torchdata.datapipes import functional_datapipe +from torchdata.datapipes.iter import IterDataPipe + +from torchdata.datapipes.utils.common import validate_pathname_binary_tuple + +try: + from math import prod # type: ignore +except ImportError: + # Implementation for older Python + # NOTE: this is not supported by mypy yet + # https://github.com/python/mypy/issues/1393 + import operator + from functools import reduce + + def prod(xs): + return reduce(operator.mul, xs, 1) + + +try: + import google.protobuf as _protobuf + + del _protobuf + HAS_PROTOBUF = True +except ImportError: + HAS_PROTOBUF = False + +U = Union[bytes, bytearray, str] +FeatureSpec = Tuple[Tuple[int, ...], torch.dtype] +ExampleSpec = Dict[str, FeatureSpec] + +# Note, reccursive types not supported by mypy at the moment +# TODO: uncomment as soon as it 
# becomes supported
# https://github.com/python/mypy/issues/731
# BinaryData = Union[str, List['BinaryData']]
BinaryData = Union[str, List[str], List[List[str]], List[List[List[Any]]]]
ExampleFeature = Union[torch.Tensor, List[torch.Tensor], BinaryData]
Example = Dict[str, ExampleFeature]


class SequenceExampleSpec(NamedTuple):
    # Spec for a SequenceExample: `context` describes the per-example features,
    # `feature_lists` describes the repeated (per-frame) features.
    context: ExampleSpec
    feature_lists: ExampleSpec


def _assert_protobuf() -> None:
    """Raise ``ModuleNotFoundError`` if the `protobuf` package is not installed."""
    if not HAS_PROTOBUF:
        raise ModuleNotFoundError(
            "Package `protobuf` is required to be installed to use this datapipe. "
            "Please use `pip install protobuf` or `conda install -c conda-forge protobuf` "
            "to install the package"
        )


def iterate_tfrecord_file(data: BufferedIOBase) -> Iterator[memoryview]:
    """Yield the serialized payload of each record in a tfrecord byte stream.

    Per record the stream holds: an 8-byte little-endian length, a 4-byte CRC,
    `length` payload bytes, and a trailing 4-byte CRC (CRCs are currently read
    but not verified). The yielded memoryview aliases a buffer that is reused
    between iterations, so consume or copy it before advancing the iterator.

    Raises:
        RuntimeError: if the stream ends in the middle of a record.
    """
    length_bytes = bytearray(8)
    crc_bytes = bytearray(4)
    data_bytes = bytearray(1024)

    while True:
        bytes_read = data.readinto(length_bytes)
        if bytes_read == 0:
            # Clean EOF: the stream ended exactly on a record boundary.
            break
        elif bytes_read != 8:
            raise RuntimeError("Invalid tfrecord file: failed to read the record size.")
        if data.readinto(crc_bytes) != 4:
            raise RuntimeError("Invalid tfrecord file: failed to read the start token.")
        (length,) = struct.unpack("<Q", length_bytes)
        if length > len(data_bytes):
            # Grow the reusable buffer with 1.5x headroom to limit reallocations.
            data_bytes = data_bytes.zfill(int(length * 1.5))
        data_bytes_view = memoryview(data_bytes)[:length]
        if data.readinto(data_bytes_view) != length:
            raise RuntimeError("Invalid tfrecord file: failed to read the record.")
        if data.readinto(crc_bytes) != 4:
            raise RuntimeError("Invalid tfrecord file: failed to read the end token.")

        # TODO: check CRC
        yield data_bytes_view


def process_feature(feature) -> torch.Tensor:
    """Convert a single tfrecord ``Feature`` proto into a tensor (or raw bytes).

    float/int64 lists become float32/int64 tensors; bytes lists are returned
    as-is (no tensor conversion is possible for arbitrary byte strings).
    """
    # NOTE: We assume that each key in the example has only one field
    # (either "bytes_list", "float_list", or "int64_list")!
    field = feature.ListFields()[0]
    inferred_typename, value = field[0].name, field[1].value
    if inferred_typename == "bytes_list":
        pass  # raw bytes pass through untouched
    elif inferred_typename == "float_list":
        value = torch.tensor(value, dtype=torch.float32)
    elif inferred_typename == "int64_list":
        value = torch.tensor(value, dtype=torch.int64)
    return value


def _reshape_list(value, shape):
    """Reshape a flat or arbitrarily nested list of scalars into nested lists of `shape`.

    At most one dimension may be -1; it is inferred from the element count.

    Raises:
        RuntimeError: if more than one -1 appears in `shape`, or the element
            count is incompatible with `shape`.
    """
    # Flatten list
    flat_list = []

    def flatten(value):
        # str/bytes are scalars here, not sequences to recurse into.
        if isinstance(value, (str, bytes)):
            flat_list.append(value)
        else:
            for x in value:
                flatten(x)

    flatten(value)

    # Compute correct shape
    common_divisor = prod(x for x in shape if x != -1)
    if sum(1 for x in shape if x == -1) > 1:
        raise RuntimeError("Shape can contain at most one dynamic dimension (-1).")
    if len(flat_list) % max(common_divisor, 1) != 0:
        raise RuntimeError(f"Cannot reshape {len(flat_list)} values into shape {shape}")
    shape = [x if x != -1 else (len(flat_list) // common_divisor) for x in shape]

    # Reshape list into the correct shape
    def _reshape(value, shape):
        if len(shape) == 0:
            assert len(value) == 1
            return value[0]
        elif len(shape) == 1:  # To make the recursion faster
            assert len(value) == shape[0]
            return value
        dim_size = len(value) // shape[0]
        # BUGFIX: iterate over the leading dimension (shape[0]), not the chunk
        # size (dim_size); the two coincide only for "square" shapes, and
        # range(dim_size) produced out-of-range (empty) chunks otherwise.
        return [_reshape(value[i * dim_size : (i + 1) * dim_size], shape[1:]) for i in range(shape[0])]

    return _reshape(flat_list, shape)


def _apply_feature_spec(value, feature_spec):
    """Apply an optional ``(shape, dtype)`` spec to a parsed feature value.

    A ``torch.dtype`` spec reshapes/casts the tensor; any other dtype marker
    triggers a manual nested-list reshape via ``_reshape_list``.
    """
    if feature_spec is not None:
        shape, dtype = feature_spec
        if isinstance(dtype, torch.dtype):
            if shape is not None:
                value = value.reshape(shape)
            value = value.to(dtype)
        elif shape is not None:
            # Manual list reshape
            value = _reshape_list(value, shape)
    return value


def _parse_tfrecord_features(features, spec: Optional[ExampleSpec]) -> Dict[str, torch.Tensor]:
    """Parse a ``Features`` proto map into a dict, filtered/shaped by `spec` if given."""
    result = dict()
    features = features.feature
    for key in features.keys():
        # When a spec is provided, silently skip keys it does not mention.
        if spec is not None and key not in spec:
            continue
        feature_spec = None if spec is None else spec[key]
        feature = features[key]
        result[key] = _apply_feature_spec(process_feature(feature), feature_spec)
    return result


def parse_tfrecord_sequence_example(example, spec: Optional[ExampleSpec]) -> Example:
    """Parse a ``SequenceExample`` proto into a dict of tensors/byte lists.

    Context features and feature lists are merged into one flat dict; a key may
    appear in only one of the two. When `spec` is given, every spec key must be
    present in the parsed result.

    Raises:
        RuntimeError: on a key duplicated between context and feature lists,
            or when `spec` keys are missing from the example.
    """
    # Parse context features
    result = cast(Example, _parse_tfrecord_features(example.context, spec))

    # Parse feature lists
    feature_lists_keys = None if spec is None else set(spec.keys()) - set(result.keys())
    features = example.feature_lists.feature_list
    for key in features.keys():
        if feature_lists_keys is not None and key not in feature_lists_keys:
            continue
        feature_spec = None if spec is None else spec[key]
        feature = features[key].feature
        if key in result:
            # BUGFIX: message was missing the f-string prefix, so "{key}" was
            # emitted literally instead of the offending key name.
            raise RuntimeError(
                f"TFRecord example's key {key} is contained in both the context and feature lists. This is not supported."
            )

        value: Union[torch.Tensor, List[Any]] = list(map(process_feature, feature))

        # For known torch dtypes, we stack the list features
        if feature_spec is not None and isinstance(feature_spec[1], torch.dtype):
            value = torch.stack(cast(List[torch.Tensor], value), 0)
        value = _apply_feature_spec(value, feature_spec)
        result[key] = value
    if spec is not None and len(result.keys()) != len(spec.keys()):
        raise RuntimeError(f"Example is missing some required keys: {sorted(result.keys())} != {sorted(spec.keys())}")
    return result


@functional_datapipe("load_from_tfrecord")
class TFRecordLoaderIterDataPipe(IterDataPipe[Example]):
    r"""
    Opens/decompresses tfrecord binary streams from an Iterable DataPipe which contains tuples of path name and
    tfrecord binary stream, and yields the stored records (functional name: ``load_from_tfrecord``).

    Args:
        datapipe: Iterable DataPipe that provides tuples of path name and tfrecord binary stream
        spec: optional mapping from feature key to ``(shape, dtype)`` used to filter, reshape and
            cast the parsed features; when ``None`` every stored key is yielded as parsed
        length: a nominal length of the DataPipe

    Note:
        The opened file handles will be closed automatically if the default ``DecoderDataPipe``
        is attached. Otherwise, user should be responsible to close file handles explicitly
        or let Python's GC close them periodically.

    Example:
        >>> from torchdata.datapipes.iter import FileLister, FileOpener
        >>> datapipe1 = FileLister(".", "*.tfrecord")
        >>> datapipe2 = FileOpener(datapipe1, mode="b")
        >>> tfrecord_loader_dp = datapipe2.load_from_tfrecord()
        >>> for example in tfrecord_loader_dp:
        >>>     print(example)
    """

    def __init__(
        self, datapipe: Iterable[Tuple[str, BufferedIOBase]], spec: Optional[ExampleSpec] = None, length: int = -1
    ) -> None:
        super().__init__()
        # Fail fast at construction time rather than on first iteration.
        _assert_protobuf()

        self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
        self.length: int = length
        self.spec = spec

    def __iter__(self) -> Iterator[Example]:
        # We assume that the "example.proto" and "feature.proto"
        # stays the same for future TensorFlow versions.
        # If it changed, newer TensorFlow versions would
        # not be able to load older tfrecord datasets.
        from .protobuf_template import _tfrecord_example_pb2 as example_pb2

        for data in self.datapipe:
            validate_pathname_binary_tuple(data)
            pathname, data_stream = data
            try:
                for example_bytes in iterate_tfrecord_file(data_stream):
                    # SequenceExample is a superset of Example (context == features),
                    # so it can parse both record flavors.
                    example = example_pb2.SequenceExample()  # type: ignore
                    example.ParseFromString(example_bytes)  # type: ignore
                    yield parse_tfrecord_sequence_example(example, self.spec)
            except RuntimeError as e:
                warnings.warn(f"Unable to read from corrupted tfrecord stream {pathname} due to: {e}, abort!")
                raise e

    def __len__(self) -> int:
        if self.length == -1:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return self.length