From 6c773bde4a2efbc3db01e13dd152b8c2fbed8f00 Mon Sep 17 00:00:00 2001 From: Jonas Kulhanek Date: Sat, 19 Mar 2022 08:08:44 +0100 Subject: [PATCH] Add tfrecord_loader implementation --- .github/workflows/_build_test_upload.yml | 2 +- .github/workflows/ci.yml | 2 +- test/test_tfrecord.py | 355 +++++++++ torchdata/datapipes/iter/__init__.py | 2 + .../iter/util/_tfrecord_example_pb2.py | 698 ++++++++++++++++++ .../datapipes/iter/util/tfrecordloader.py | 225 ++++++ 6 files changed, 1282 insertions(+), 2 deletions(-) create mode 100644 test/test_tfrecord.py create mode 100644 torchdata/datapipes/iter/util/_tfrecord_example_pb2.py create mode 100644 torchdata/datapipes/iter/util/tfrecordloader.py diff --git a/.github/workflows/_build_test_upload.yml b/.github/workflows/_build_test_upload.yml index 4879c31fd..ec4cf8877 100644 --- a/.github/workflows/_build_test_upload.yml +++ b/.github/workflows/_build_test_upload.yml @@ -150,7 +150,7 @@ jobs: done - name: Install Test Requirements if: steps.trigger_build.outputs.value == 'true' - run: pip3 install expecttest fsspec iopath==0.1.9 numpy pytest rarfile + run: pip3 install expecttest fsspec iopath==0.1.9 numpy pytest rarfile tensorflow - name: Run DataPipes Tests with pytest if: steps.trigger_build.outputs.value == 'true' run: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b69739a6f..e19f546c6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,7 +47,7 @@ jobs: pip3 install -r requirements.txt pip3 install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - name: Install test requirements - run: pip3 install expecttest fsspec iopath==0.1.9 numpy pytest rarfile + run: pip3 install expecttest fsspec iopath==0.1.9 numpy pytest rarfile tensorflow - name: Build TorchData run: python setup.py develop - name: Run DataPipes tests with pytest diff --git a/test/test_tfrecord.py b/test/test_tfrecord.py new file mode 100644 index 000000000..da60ea8a1 --- 
/dev/null +++ b/test/test_tfrecord.py @@ -0,0 +1,355 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os +import unittest +import warnings +from functools import partial + +import expecttest +import numpy as np + +import torch + +from _utils._common_utils_for_test import create_temp_dir, reset_after_n_next_calls +from torchdata.datapipes.iter import ( + FileLister, + FileOpener, + FSSpecFileLister, + FSSpecFileOpener, + FSSpecSaver, + IterableWrapper, + TFRecordLoader, +) + +try: + import tensorflow as tf + + HAS_TF = True +except ImportError: + HAS_TF = False +skipIfNoTF = unittest.skipIf(not HAS_TF, "no tensorflow") + + +def create_temp_tfrecord_files(temp_dir: str): + with tf.io.TFRecordWriter(os.path.join(temp_dir, "example.tfrecord")) as writer: + for _ in range(4): + x = tf.random.uniform( + [ + 10, + ] + ) + + record_bytes = tf.train.Example( + features=tf.train.Features( + feature={ + "x_float": tf.train.Feature(float_list=tf.train.FloatList(value=x)), + "x_int": tf.train.Feature(int64_list=tf.train.Int64List(value=tf.cast(x * 10, "int64"))), + "x_byte": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"test str"])), + } + ) + ).SerializeToString() + writer.write(record_bytes) + + with tf.io.TFRecordWriter(os.path.join(temp_dir, "sequence_example.tfrecord")) as writer: + for _ in range(4): + x = tf.random.uniform( + [ + 10, + ] + ) + rep = int( + tf.random.uniform( + [ + 1, + ] + ).numpy()[0] + * 10 + + 1 + ) + + record_bytes = tf.train.SequenceExample( + context=tf.train.Features( + feature={ + "x_float": tf.train.Feature(float_list=tf.train.FloatList(value=x)), + "x_int": tf.train.Feature(int64_list=tf.train.Int64List(value=tf.cast(x * 10, "int64"))), + "x_byte": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"test str"])), + } + ), + feature_lists=tf.train.FeatureLists( + feature_list={ + "x_float_seq": tf.train.FeatureList( + feature=[tf.train.Feature(float_list=tf.train.FloatList(value=x))] * rep + ), + "x_int_seq": 
tf.train.FeatureList( + feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=tf.cast(x * 10, "int64")))] + * rep + ), + "x_byte_seq": tf.train.FeatureList( + feature=[tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"test str"]))] * rep + ), + } + ), + ).SerializeToString() + writer.write(record_bytes) + + +class TestDataPipeTFRecord(expecttest.TestCase): + def setUp(self): + self.temp_dir = create_temp_dir() + self.temp_files = create_temp_tfrecord_files(self.temp_dir.name) + + def tearDown(self): + try: + self.temp_dir.cleanup() + except Exception as e: + warnings.warn(f"TestDataPipeTFRecord was not able to cleanup temp dir due to {e}") + + def assertArrayEqual(self, arr1, arr2): + np.testing.assert_array_equal(arr1, arr2) + + @skipIfNoTF + @torch.no_grad() + def test_tfrecord_loader_example_iterdatapipe(self): + filename = f"{self.temp_dir.name}/example.tfrecord" + datapipe1 = IterableWrapper([filename]) + datapipe2 = FileOpener(datapipe1, mode="b") + + # Functional Test: test if the returned data is correct + tfrecord_parser = datapipe2.load_from_tfrecord() + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + decode_fn = partial( + tf.io.parse_single_example, + features={ + "x_float": tf.io.FixedLenFeature([10], tf.float32), + "x_int": tf.io.FixedLenFeature([10], tf.int64), + "x_byte": tf.io.FixedLenFeature([], tf.string), + }, + ) + expected_res = final_expected_res = list(tf.data.TFRecordDataset([filename]).map(decode_fn)) + for true_data, loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(len(loaded_data["x_byte"]), 1) + self.assertEqual(true_data["x_byte"].numpy(), loaded_data["x_byte"][0]) + + # Functional Test: test if the shape of the returned data is correct when using spec + tfrecord_parser = datapipe2.load_from_tfrecord( + { + 
"x_float": ((5, 2), torch.float64), + "x_int": ((5, 2), torch.int32), + "x_byte": (tuple(), None), + } + ) + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + decode_fn = partial( + tf.io.parse_single_example, + features={ + "x_float": tf.io.FixedLenFeature([5, 2], tf.float32), + "x_int": tf.io.FixedLenFeature([5, 2], tf.int64), + "x_byte": tf.io.FixedLenFeature([], tf.string), + }, + ) + expected_res = list(tf.data.TFRecordDataset([filename]).map(decode_fn)) + for true_data, loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + self.assertArrayEqual(true_data["x_float"].numpy(), loaded_data["x_float"].float().numpy()) + self.assertArrayEqual(true_data["x_int"].numpy(), loaded_data["x_int"].long().numpy()) + self.assertEqual(loaded_data["x_float"].dtype, torch.float64) + self.assertEqual(loaded_data["x_int"].dtype, torch.int32) + self.assertEqual(true_data["x_byte"].numpy(), loaded_data["x_byte"]) + + # Functional Test: ignore features missing from spec + tfrecord_parser = datapipe2.load_from_tfrecord( + { + "x_float": ((10,), torch.float32), + } + ) + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + decode_fn = partial( + tf.io.parse_single_example, + features={ + "x_float": tf.io.FixedLenFeature([10], tf.float32), + }, + ) + expected_res = list(tf.data.TFRecordDataset([filename]).map(decode_fn)) + for true_data, loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + self.assertArrayEqual(true_data["x_float"].numpy(), loaded_data["x_float"].float().numpy()) + + # Functional Test: raises error if missing spec feature + with self.assertRaises(RuntimeError): + tfrecord_parser = datapipe2.load_from_tfrecord( + { + "x_float_unknown": ((5, 2), torch.float64), + "x_int": ((5, 2), torch.int32), + "x_byte": (tuple(), None), + } + ) + result = list(tfrecord_parser) + + # Reset Test: + tfrecord_parser = 
TFRecordLoader(datapipe2) + expected_res = final_expected_res + n_elements_before_reset = 2 + res_before_reset, res_after_reset = reset_after_n_next_calls(tfrecord_parser, n_elements_before_reset) + self.assertEqual(len(expected_res[:n_elements_before_reset]), len(res_before_reset)) + for true_data, loaded_data in zip(expected_res[:n_elements_before_reset], res_before_reset): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(true_data["x_byte"].numpy(), loaded_data["x_byte"][0]) + self.assertEqual(len(expected_res), len(res_after_reset)) + for true_data, loaded_data in zip(expected_res, res_after_reset): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(true_data["x_byte"].numpy(), loaded_data["x_byte"][0]) + + # __len__ Test: length isn't implemented since it cannot be known ahead of time + with self.assertRaisesRegex(TypeError, "doesn't have valid length"): + len(tfrecord_parser) + + @skipIfNoTF + @torch.no_grad() + def test_tfrecord_loader_sequence_example_iterdatapipe(self): + filename = f"{self.temp_dir.name}/sequence_example.tfrecord" + datapipe1 = IterableWrapper([filename]) + datapipe2 = FileOpener(datapipe1, mode="b") + + # Functional Test: test if the returned data is correct + tfrecord_parser = datapipe2.load_from_tfrecord() + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + decode_fn = partial( + tf.io.parse_single_sequence_example, + context_features={ + "x_float": tf.io.FixedLenFeature([10], tf.float32), + "x_int": tf.io.FixedLenFeature([10], tf.int64), + "x_byte": tf.io.FixedLenFeature([1], tf.string), + }, + sequence_features={ + "x_float_seq": tf.io.RaggedFeature(tf.float32), + "x_int_seq": tf.io.RaggedFeature(tf.int64), + "x_byte_seq": 
tf.io.RaggedFeature(tf.string), + }, + ) + expected_res = final_expected_res = list(tf.data.TFRecordDataset([filename]).map(decode_fn)) + for (true_data_ctx, true_data_seq), loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data_ctx[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(true_data_seq[key + "_seq"].to_tensor().shape[0], len(loaded_data[key + "_seq"])) + self.assertIsInstance(loaded_data[key + "_seq"], list) + for a1, a2 in zip(true_data_seq[key + "_seq"], loaded_data[key + "_seq"]): + self.assertArrayEqual(a1, a2) + self.assertEqual(true_data_ctx["x_byte"].numpy(), loaded_data["x_byte"]) + self.assertListEqual(list(true_data_seq["x_byte_seq"].to_tensor().numpy()), loaded_data["x_byte_seq"]) + + # Functional Test: test if the shape of the returned data is correct when using spec + tfrecord_parser = datapipe2.load_from_tfrecord( + { + "x_float": ((5, 2), torch.float64), + "x_int": ((5, 2), torch.int32), + "x_byte": (tuple(), None), + "x_float_seq": ((-1, 5, 2), torch.float64), + "x_int_seq": ((-1, 5, 2), torch.int32), + "x_byte_seq": ((-1,), None), + } + ) + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + decode_fn = partial( + tf.io.parse_single_sequence_example, + context_features={ + "x_float": tf.io.FixedLenFeature([5, 2], tf.float32), + "x_int": tf.io.FixedLenFeature([5, 2], tf.int64), + "x_byte": tf.io.FixedLenFeature([], tf.string), + }, + sequence_features={ + "x_float_seq": tf.io.RaggedFeature(tf.float32), + "x_int_seq": tf.io.RaggedFeature(tf.int64), + "x_byte_seq": tf.io.RaggedFeature(tf.string), + }, + ) + expected_res = list(tf.data.TFRecordDataset([filename]).map(decode_fn)) + for (true_data_ctx, true_data_seq), loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) + for 
key in ["x_float", "x_int"]: + l_loaded_data = loaded_data[key] + if key == "x_float": + l_loaded_data = l_loaded_data.float() + else: + l_loaded_data = l_loaded_data.int() + self.assertArrayEqual(true_data_ctx[key].numpy(), l_loaded_data.numpy()) + self.assertArrayEqual( + tf.reshape(true_data_seq[key + "_seq"].to_tensor(), [-1, 5, 2]), loaded_data[key + "_seq"] + ) + self.assertEqual(true_data_ctx["x_byte"].numpy(), loaded_data["x_byte"]) + self.assertListEqual(list(true_data_seq["x_byte_seq"].to_tensor().numpy()), loaded_data["x_byte_seq"]) + + # Functional Test: ignore features missing from spec + tfrecord_parser = datapipe2.load_from_tfrecord( + { + "x_float": ((10,), torch.float32), + } + ) + result = list(tfrecord_parser) + self.assertEqual(len(result), 4) + decode_fn = partial( + tf.io.parse_single_example, + features={ + "x_float": tf.io.FixedLenFeature([10], tf.float32), + }, + ) + expected_res = list(tf.data.TFRecordDataset([filename]).map(decode_fn)) + for true_data, loaded_data in zip(expected_res, result): + self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) + self.assertArrayEqual(true_data["x_float"].numpy(), loaded_data["x_float"].float().numpy()) + + # Functional Test: raises error if missing spec feature + with self.assertRaises(RuntimeError): + tfrecord_parser = datapipe2.load_from_tfrecord( + {"x_float_unknown": ((5, 2), torch.float64), "x_int": ((5, 2), torch.int32), "x_byte": None} + ) + result = list(tfrecord_parser) + + # Reset Test: + tfrecord_parser = TFRecordLoader(datapipe2) + expected_res = final_expected_res + n_elements_before_reset = 2 + res_before_reset, res_after_reset = reset_after_n_next_calls(tfrecord_parser, n_elements_before_reset) + self.assertEqual(len(expected_res[:n_elements_before_reset]), len(res_before_reset)) + for (true_data_ctx, true_data_seq), loaded_data in zip( + expected_res[:n_elements_before_reset], res_before_reset + ): + 
self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data_ctx[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(true_data_seq[key + "_seq"].to_tensor().shape[0], len(loaded_data[key + "_seq"])) + self.assertIsInstance(loaded_data[key + "_seq"], list) + for a1, a2 in zip(true_data_seq[key + "_seq"], loaded_data[key + "_seq"]): + self.assertArrayEqual(a1, a2) + self.assertEqual(true_data_ctx["x_byte"].numpy(), loaded_data["x_byte"]) + self.assertListEqual(list(true_data_seq["x_byte_seq"].to_tensor().numpy()), loaded_data["x_byte_seq"]) + self.assertEqual(len(expected_res), len(res_after_reset)) + for (true_data_ctx, true_data_seq), loaded_data in zip(expected_res, res_after_reset): + self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) + for key in ["x_float", "x_int"]: + self.assertArrayEqual(true_data_ctx[key].numpy(), loaded_data[key].numpy()) + self.assertEqual(true_data_seq[key + "_seq"].to_tensor().shape[0], len(loaded_data[key + "_seq"])) + self.assertIsInstance(loaded_data[key + "_seq"], list) + for a1, a2 in zip(true_data_seq[key + "_seq"], loaded_data[key + "_seq"]): + self.assertArrayEqual(a1, a2) + self.assertEqual(true_data_ctx["x_byte"].numpy(), loaded_data["x_byte"]) + self.assertListEqual(list(true_data_seq["x_byte_seq"].to_tensor().numpy()), loaded_data["x_byte_seq"]) + + # __len__ Test: length isn't implemented since it cannot be known ahead of time + with self.assertRaisesRegex(TypeError, "doesn't have valid length"): + len(tfrecord_parser) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index aecd0862e..6ec0079a5 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -83,6 +83,7 @@ TarArchiveLoaderIterDataPipe as TarArchiveLoader, 
TarArchiveReaderIterDataPipe as TarArchiveReader, ) +from torchdata.datapipes.iter.util.tfrecordloader import TFRecordLoaderIterDataPipe as TFRecordLoader from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper from torchdata.datapipes.iter.util.xzfileloader import ( XzFileLoaderIterDataPipe as XzFileLoader, @@ -152,6 +153,7 @@ "ShardingFilter", "Shuffler", "StreamReader", + "TFRecordLoader", "TarArchiveLoader", "TarArchiveReader", "UnBatcher", diff --git a/torchdata/datapipes/iter/util/_tfrecord_example_pb2.py b/torchdata/datapipes/iter/util/_tfrecord_example_pb2.py new file mode 100644 index 000000000..bb799ec07 --- /dev/null +++ b/torchdata/datapipes/iter/util/_tfrecord_example_pb2.py @@ -0,0 +1,698 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: example.proto + +import sys + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) +from google.protobuf import ( + descriptor as _descriptor, + descriptor_pb2, + message as _message, + reflection as _reflection, + symbol_database as _symbol_database, +) + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor.FileDescriptor( + name="example.proto", + package="tfrecord", + syntax="proto3", + serialized_pb=_b( + '\n\rexample.proto\x12\x08tfrecord"\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c"\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01"\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01"\x92\x01\n\x07\x46\x65\x61ture\x12)\n\nbytes_list\x18\x01 \x01(\x0b\x32\x13.tfrecord.BytesListH\x00\x12)\n\nfloat_list\x18\x02 \x01(\x0b\x32\x13.tfrecord.FloatListH\x00\x12)\n\nint64_list\x18\x03 \x01(\x0b\x32\x13.tfrecord.Int64ListH\x00\x42\x06\n\x04kind"\x7f\n\x08\x46\x65\x61tures\x12\x30\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x1f.tfrecord.Features.FeatureEntry\x1a\x41\n\x0c\x46\x65\x61tureEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 
\n\x05value\x18\x02 \x01(\x0b\x32\x11.tfrecord.Feature:\x02\x38\x01"1\n\x0b\x46\x65\x61tureList\x12"\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x11.tfrecord.Feature"\x98\x01\n\x0c\x46\x65\x61tureLists\x12=\n\x0c\x66\x65\x61ture_list\x18\x01 \x03(\x0b\x32\'.tfrecord.FeatureLists.FeatureListEntry\x1aI\n\x10\x46\x65\x61tureListEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tfrecord.FeatureList:\x02\x38\x01"/\n\x07\x45xample\x12$\n\x08\x66\x65\x61tures\x18\x01 \x01(\x0b\x32\x12.tfrecord.Features"e\n\x0fSequenceExample\x12#\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x12.tfrecord.Features\x12-\n\rfeature_lists\x18\x02 \x01(\x0b\x32\x16.tfrecord.FeatureListsB\x03\xf8\x01\x01\x62\x06proto3' + ), +) +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + + +_BYTESLIST = _descriptor.Descriptor( + name="BytesList", + full_name="tfrecord.BytesList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="value", + full_name="tfrecord.BytesList.value", + index=0, + number=1, + type=12, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=27, + serialized_end=53, +) + + +_FLOATLIST = _descriptor.Descriptor( + name="FloatList", + full_name="tfrecord.FloatList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="value", + full_name="tfrecord.FloatList.value", + index=0, + number=1, + type=2, + cpp_type=6, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), 
_b("\020\001")), + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=55, + serialized_end=85, +) + + +_INT64LIST = _descriptor.Descriptor( + name="Int64List", + full_name="tfrecord.Int64List", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="value", + full_name="tfrecord.Int64List.value", + index=0, + number=1, + type=3, + cpp_type=2, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=87, + serialized_end=117, +) + + +_FEATURE = _descriptor.Descriptor( + name="Feature", + full_name="tfrecord.Feature", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="bytes_list", + full_name="tfrecord.Feature.bytes_list", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="float_list", + full_name="tfrecord.Feature.float_list", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="int64_list", + full_name="tfrecord.Feature.int64_list", + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + 
default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="kind", full_name="tfrecord.Feature.kind", index=0, containing_type=None, fields=[] + ), + ], + serialized_start=120, + serialized_end=266, +) + + +_FEATURES_FEATUREENTRY = _descriptor.Descriptor( + name="FeatureEntry", + full_name="tfrecord.Features.FeatureEntry", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="tfrecord.Features.FeatureEntry.key", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="value", + full_name="tfrecord.Features.FeatureEntry.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=330, + serialized_end=395, +) + +_FEATURES = _descriptor.Descriptor( + name="Features", + full_name="tfrecord.Features", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="feature", + full_name="tfrecord.Features.feature", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + 
message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[ + _FEATURES_FEATUREENTRY, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=268, + serialized_end=395, +) + + +_FEATURELIST = _descriptor.Descriptor( + name="FeatureList", + full_name="tfrecord.FeatureList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="feature", + full_name="tfrecord.FeatureList.feature", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=397, + serialized_end=446, +) + + +_FEATURELISTS_FEATURELISTENTRY = _descriptor.Descriptor( + name="FeatureListEntry", + full_name="tfrecord.FeatureLists.FeatureListEntry", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="tfrecord.FeatureLists.FeatureListEntry.key", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="value", + full_name="tfrecord.FeatureLists.FeatureListEntry.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + 
extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=528, + serialized_end=601, +) + +_FEATURELISTS = _descriptor.Descriptor( + name="FeatureLists", + full_name="tfrecord.FeatureLists", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="feature_list", + full_name="tfrecord.FeatureLists.feature_list", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[ + _FEATURELISTS_FEATURELISTENTRY, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=449, + serialized_end=601, +) + + +_EXAMPLE = _descriptor.Descriptor( + name="Example", + full_name="tfrecord.Example", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="features", + full_name="tfrecord.Example.features", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=603, + serialized_end=650, +) + + +_SEQUENCEEXAMPLE = _descriptor.Descriptor( + name="SequenceExample", + full_name="tfrecord.SequenceExample", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="context", + full_name="tfrecord.SequenceExample.context", + index=0, 
+ number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="feature_lists", + full_name="tfrecord.SequenceExample.feature_lists", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=652, + serialized_end=753, +) + +_FEATURE.fields_by_name["bytes_list"].message_type = _BYTESLIST +_FEATURE.fields_by_name["float_list"].message_type = _FLOATLIST +_FEATURE.fields_by_name["int64_list"].message_type = _INT64LIST +_FEATURE.oneofs_by_name["kind"].fields.append(_FEATURE.fields_by_name["bytes_list"]) +_FEATURE.fields_by_name["bytes_list"].containing_oneof = _FEATURE.oneofs_by_name["kind"] +_FEATURE.oneofs_by_name["kind"].fields.append(_FEATURE.fields_by_name["float_list"]) +_FEATURE.fields_by_name["float_list"].containing_oneof = _FEATURE.oneofs_by_name["kind"] +_FEATURE.oneofs_by_name["kind"].fields.append(_FEATURE.fields_by_name["int64_list"]) +_FEATURE.fields_by_name["int64_list"].containing_oneof = _FEATURE.oneofs_by_name["kind"] +_FEATURES_FEATUREENTRY.fields_by_name["value"].message_type = _FEATURE +_FEATURES_FEATUREENTRY.containing_type = _FEATURES +_FEATURES.fields_by_name["feature"].message_type = _FEATURES_FEATUREENTRY +_FEATURELIST.fields_by_name["feature"].message_type = _FEATURE +_FEATURELISTS_FEATURELISTENTRY.fields_by_name["value"].message_type = _FEATURELIST +_FEATURELISTS_FEATURELISTENTRY.containing_type = _FEATURELISTS +_FEATURELISTS.fields_by_name["feature_list"].message_type = 
_FEATURELISTS_FEATURELISTENTRY +_EXAMPLE.fields_by_name["features"].message_type = _FEATURES +_SEQUENCEEXAMPLE.fields_by_name["context"].message_type = _FEATURES +_SEQUENCEEXAMPLE.fields_by_name["feature_lists"].message_type = _FEATURELISTS +DESCRIPTOR.message_types_by_name["BytesList"] = _BYTESLIST +DESCRIPTOR.message_types_by_name["FloatList"] = _FLOATLIST +DESCRIPTOR.message_types_by_name["Int64List"] = _INT64LIST +DESCRIPTOR.message_types_by_name["Feature"] = _FEATURE +DESCRIPTOR.message_types_by_name["Features"] = _FEATURES +DESCRIPTOR.message_types_by_name["FeatureList"] = _FEATURELIST +DESCRIPTOR.message_types_by_name["FeatureLists"] = _FEATURELISTS +DESCRIPTOR.message_types_by_name["Example"] = _EXAMPLE +DESCRIPTOR.message_types_by_name["SequenceExample"] = _SEQUENCEEXAMPLE + +BytesList = _reflection.GeneratedProtocolMessageType( + "BytesList", + (_message.Message,), + dict( + DESCRIPTOR=_BYTESLIST, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.BytesList) + ), +) +_sym_db.RegisterMessage(BytesList) + +FloatList = _reflection.GeneratedProtocolMessageType( + "FloatList", + (_message.Message,), + dict( + DESCRIPTOR=_FLOATLIST, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.FloatList) + ), +) +_sym_db.RegisterMessage(FloatList) + +Int64List = _reflection.GeneratedProtocolMessageType( + "Int64List", + (_message.Message,), + dict( + DESCRIPTOR=_INT64LIST, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.Int64List) + ), +) +_sym_db.RegisterMessage(Int64List) + +Feature = _reflection.GeneratedProtocolMessageType( + "Feature", + (_message.Message,), + dict( + DESCRIPTOR=_FEATURE, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.Feature) + ), +) +_sym_db.RegisterMessage(Feature) + +Features = _reflection.GeneratedProtocolMessageType( + "Features", + (_message.Message,), + dict( + FeatureEntry=_reflection.GeneratedProtocolMessageType( + "FeatureEntry", 
+ (_message.Message,), + dict( + DESCRIPTOR=_FEATURES_FEATUREENTRY, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.Features.FeatureEntry) + ), + ), + DESCRIPTOR=_FEATURES, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.Features) + ), +) +_sym_db.RegisterMessage(Features) +_sym_db.RegisterMessage(Features.FeatureEntry) + +FeatureList = _reflection.GeneratedProtocolMessageType( + "FeatureList", + (_message.Message,), + dict( + DESCRIPTOR=_FEATURELIST, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.FeatureList) + ), +) +_sym_db.RegisterMessage(FeatureList) + +FeatureLists = _reflection.GeneratedProtocolMessageType( + "FeatureLists", + (_message.Message,), + dict( + FeatureListEntry=_reflection.GeneratedProtocolMessageType( + "FeatureListEntry", + (_message.Message,), + dict( + DESCRIPTOR=_FEATURELISTS_FEATURELISTENTRY, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.FeatureLists.FeatureListEntry) + ), + ), + DESCRIPTOR=_FEATURELISTS, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.FeatureLists) + ), +) +_sym_db.RegisterMessage(FeatureLists) +_sym_db.RegisterMessage(FeatureLists.FeatureListEntry) + +Example = _reflection.GeneratedProtocolMessageType( + "Example", + (_message.Message,), + dict( + DESCRIPTOR=_EXAMPLE, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.Example) + ), +) +_sym_db.RegisterMessage(Example) + +SequenceExample = _reflection.GeneratedProtocolMessageType( + "SequenceExample", + (_message.Message,), + dict( + DESCRIPTOR=_SEQUENCEEXAMPLE, + __module__="example_pb2" + # @@protoc_insertion_point(class_scope:tfrecord.SequenceExample) + ), +) +_sym_db.RegisterMessage(SequenceExample) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) +_FLOATLIST.fields_by_name["value"].has_options = True 
# Remaining machine-generated protoc option payloads; do not edit by hand.
_FLOATLIST.fields_by_name["value"]._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001"))
_INT64LIST.fields_by_name["value"].has_options = True
_INT64LIST.fields_by_name["value"]._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001"))
_FEATURES_FEATUREENTRY.has_options = True
_FEATURES_FEATUREENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001"))
_FEATURELISTS_FEATURELISTENTRY.has_options = True
_FEATURELISTS_FEATURELISTENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001"))
# @@protoc_insertion_point(module_scope)
diff --git a/torchdata/datapipes/iter/util/tfrecordloader.py b/torchdata/datapipes/iter/util/tfrecordloader.py
new file mode 100644
index 000000000..6f455d57d
--- /dev/null
+++ b/torchdata/datapipes/iter/util/tfrecordloader.py
@@ -0,0 +1,225 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import struct
import warnings
from functools import partial
from io import BufferedIOBase
from math import prod
from typing import Dict, Iterable, Iterator, NamedTuple, Tuple, Union

import torch

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

from torchdata.datapipes.utils.common import validate_pathname_binary_tuple

# Probe for an optional protobuf dependency without keeping a module
# reference around; only the boolean flag is used later.
try:
    import google.protobuf as _protobuf

    del _protobuf
    HAS_PROTOBUF = True
except ImportError:
    HAS_PROTOBUF = False

# NOTE(review): `U` appears unused within this chunk — presumably the
# bytes-like scalar types a BytesList feature can hold; confirm usage
# elsewhere before removing.
U = Union[bytes, bytearray, str]
# A feature spec is a (shape, dtype) pair; see `_apply_feature_spec`.
FeatureSpec = Tuple[Tuple[int, ...], torch.dtype]
# Maps a feature key to its (shape, dtype) spec.
ExampleSpec = Dict[str, FeatureSpec]


class SequenceExampleSpec(NamedTuple):
    """Spec for a tf SequenceExample: one spec for the context features,
    one for the feature lists."""

    context: ExampleSpec
    feature_lists: ExampleSpec


def _assert_protobuf() -> None:
    """Raise ModuleNotFoundError if the optional `protobuf` package is absent."""
    if not HAS_PROTOBUF:
        raise ModuleNotFoundError(
            "Package `protobuf` is required to be installed to use this datapipe."
+ "Please use `pip install protobuf` or `conda install -c conda-forge protobuf`" + "to install the package" + ) + + +def iterate_tfrecord_file(data: BufferedIOBase) -> Iterator[memoryview]: + length_bytes = bytearray(8) + crc_bytes = bytearray(4) + data_bytes = bytearray(1024) + + while True: + bytes_read = data.readinto(length_bytes) + if bytes_read == 0: + break + elif bytes_read != 8: + raise RuntimeError("Invalid tfrecord file: failed to read the record size.") + if data.readinto(crc_bytes) != 4: + raise RuntimeError("Invalid tfrecord file: failed to read the start token.") + (length,) = struct.unpack(" len(data_bytes): + data_bytes = data_bytes.zfill(int(length * 1.5)) + data_bytes_view = memoryview(data_bytes)[:length] + if data.readinto(data_bytes_view) != length: + raise RuntimeError("Invalid tfrecord file: failed to read the record.") + if data.readinto(crc_bytes) != 4: + raise RuntimeError("Invalid tfrecord file: failed to read the end token.") + + # TODO: check CRC + yield data_bytes_view + + +def process_feature(feature) -> torch.Tensor: + # NOTE: We assume that each key in the example has only one field + # (either "bytes_list", "float_list", or "int64_list")! 
+ field = feature.ListFields()[0] + inferred_typename, value = field[0].name, field[1].value + if inferred_typename == "bytes_list": + pass + elif inferred_typename == "float_list": + value = torch.tensor(value, dtype=torch.float32) + elif inferred_typename == "int64_list": + value = torch.tensor(value, dtype=torch.int64) + return value + + +def _reshape_list(value, shape): + # Flatten list + flat_list = [] + + def flatten(value): + if isinstance(value, (str, bytes)): + flat_list.append(value) + else: + for x in value: + flatten(x) + + flatten(value) + + # Compute correct shape + common_divisor = prod(x for x in shape if x != -1) + if sum(1 for x in shape if x == -1) > 1: + raise RuntimeError("Shape can contain at most one dynamic dimension (-1).") + if len(flat_list) % max(common_divisor, 1) != 0: + raise RuntimeError(f"Cannot reshape {len(flat_list)} values into shape {shape}") + shape = [x if x != -1 else (len(flat_list) // common_divisor) for x in shape] + + # Reshape list into the correct shape + def _reshape(value, shape): + if len(shape) == 0: + assert len(value) == 1 + return value[0] + elif len(shape) == 1: # To make the reccursion faster + assert len(value) == shape[0] + return value + dim_size = len(value) // shape[0] + return [_reshape(value[i * dim_size : (i + 1) * dim_size], shape[1:]) for i in range(dim_size)] + + return _reshape(flat_list, shape) + + +def _apply_feature_spec(value, feature_spec): + if feature_spec is not None: + shape, dtype = feature_spec + if isinstance(dtype, torch.dtype): + if shape is not None: + value = value.reshape(shape) + value = value.to(dtype) + elif shape is not None: + # Manual list reshape + value = _reshape_list(value, shape) + return value + + +def _parse_tfrecord_features(features, spec: ExampleSpec) -> Dict[str, torch.Tensor]: + result = dict() + features = features.feature + for key in features.keys(): + if spec is not None and key not in spec: + continue + feature_spec = None if spec is None else spec[key] + 
feature = features[key] + result[key] = _apply_feature_spec(process_feature(feature), feature_spec) + return result + + +def parse_tfrecord_sequence_example(example, spec: ExampleSpec) -> Dict[str, torch.Tensor]: + # Parse context features + result = _parse_tfrecord_features(example.context, spec) + + # Parse feature lists + feature_lists_keys = None if spec is None else set(spec.keys()) - set(result.keys()) + features = example.feature_lists.feature_list + for key in features.keys(): + if feature_lists_keys is not None and key not in feature_lists_keys: + continue + feature_spec = None if spec is None else spec[key] + feature = features[key].feature + if key in result: + raise RuntimeError( + "TFRecord example's key {key} is contained in both the context and feature lists. This is not supported." + ) + + value = list(map(partial(process_feature), feature)) + + # For known torch dtypes, we stack the list features + if feature_spec is not None and isinstance(feature_spec[1], torch.dtype): + value = torch.stack(value, 0) + value = _apply_feature_spec(value, feature_spec) + result[key] = value + if spec is not None and len(result.keys()) != len(spec.keys()): + raise RuntimeError(f"Example is missing some required keys: {sorted(result.keys())} != {sorted(spec.keys())}") + return result + + +@functional_datapipe("load_from_tfrecord") +class TFRecordLoaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): + r""" + Opens/decompresses tfrecord binary streams from an Iterable DataPipe which contains tuples of path name and + tfrecord binary stream, and yields the stored records (functional name: ``load_from_tfrecord``). + + Args: + datapipe: Iterable DataPipe that provides tuples of path name and tfrecord binary stream + length: a nominal length of the DataPipe + + Note: + The opened file handles will be closed automatically if the default ``DecoderDataPipe`` + is attached. 
Otherwise, user should be responsible to close file handles explicitly + or let Python's GC close them periodically. + + Example: + >>> from torchdata.datapipes.iter import FileLister, FileOpener + >>> datapipe1 = FileLister(".", "*.tfrecord") + >>> datapipe2 = FileOpener(datapipe1, mode="b") + >>> tfrecord_loader_dp = datapipe2.load_from_tar() + >>> for example in tfrecord_loader_dp: + >>> print(example) + b'0123456789abcdef' + """ + + def __init__( + self, datapipe: Iterable[Tuple[str, BufferedIOBase]], spec: ExampleSpec = None, length: int = -1 + ) -> None: + super().__init__() + _assert_protobuf() + + self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe + self.length: int = length + self.spec = spec + + def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: + from . import _tfrecord_example_pb2 as example_pb2 + + for data in self.datapipe: + validate_pathname_binary_tuple(data) + pathname, data_stream = data + try: + for example_bytes in iterate_tfrecord_file(data_stream): + example = example_pb2.SequenceExample() + example.ParseFromString(example_bytes) + yield parse_tfrecord_sequence_example(example, self.spec) + except RuntimeError as e: + warnings.warn(f"Unable to read from corrupted tfrecord stream {pathname} due to: {e}, abort!") + raise e + + def __len__(self) -> int: + if self.length == -1: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + return self.length