diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index d13cef0d3..ab34604f6 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -102,7 +102,11 @@ TarArchiveLoaderIterDataPipe as TarArchiveLoader, TarArchiveReaderIterDataPipe as TarArchiveReader, ) -from torchdata.datapipes.iter.util.tfrecordloader import TFRecordLoaderIterDataPipe as TFRecordLoader +from torchdata.datapipes.iter.util.tfrecordloader import ( + TFRecordExample, + TFRecordExampleSpec, + TFRecordLoaderIterDataPipe as TFRecordLoader, +) from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper from torchdata.datapipes.iter.util.xzfileloader import ( XzFileLoaderIterDataPipe as XzFileLoader, diff --git a/torchdata/datapipes/iter/util/tfrecordloader.py b/torchdata/datapipes/iter/util/tfrecordloader.py index 3fe3a1627..e293acbbb 100644 --- a/torchdata/datapipes/iter/util/tfrecordloader.py +++ b/torchdata/datapipes/iter/util/tfrecordloader.py @@ -39,21 +39,21 @@ def prod(xs): HAS_PROTOBUF = False U = Union[bytes, bytearray, str] -FeatureSpec = Tuple[Tuple[int, ...], torch.dtype] -ExampleSpec = Dict[str, FeatureSpec] +TFRecordFeatureSpec = Tuple[Tuple[int, ...], torch.dtype] +TFRecordExampleSpec = Dict[str, TFRecordFeatureSpec] # Note, reccursive types not supported by mypy at the moment # TODO: uncomment as soon as it becomes supported # https://github.com/python/mypy/issues/731 # BinaryData = Union[str, List['BinaryData']] -BinaryData = Union[str, List[str], List[List[str]], List[List[List[Any]]]] -ExampleFeature = Union[torch.Tensor, List[torch.Tensor], BinaryData] -Example = Dict[str, ExampleFeature] +TFRecordBinaryData = Union[str, List[str], List[List[str]], List[List[List[Any]]]] +TFRecordExampleFeature = Union[torch.Tensor, List[torch.Tensor], TFRecordBinaryData] +TFRecordExample = Dict[str, TFRecordExampleFeature] class SequenceExampleSpec(NamedTuple): - context: ExampleSpec - feature_lists: ExampleSpec + context: TFRecordExampleSpec + feature_lists: TFRecordExampleSpec def _assert_protobuf() -> None: @@ -153,7 +153,7 @@ def _apply_feature_spec(value, feature_spec): return value -def _parse_tfrecord_features(features, spec: Optional[ExampleSpec]) -> Dict[str, torch.Tensor]: +def _parse_tfrecord_features(features, spec: Optional[TFRecordExampleSpec]) -> Dict[str, torch.Tensor]: result = dict() features = features.feature for key in features.keys(): @@ -165,9 +165,9 @@ def _parse_tfrecord_features(features, spec: Optional[ExampleSpec]) -> Dict[str, return result -def parse_tfrecord_sequence_example(example, spec: Optional[ExampleSpec]) -> Example: +def parse_tfrecord_sequence_example(example, spec: Optional[TFRecordExampleSpec]) -> TFRecordExample: # Parse context features - result = cast(Example, _parse_tfrecord_features(example.context, spec)) + result = cast(TFRecordExample, _parse_tfrecord_features(example.context, spec)) # Parse feature lists feature_lists_keys = None if spec is None else set(spec.keys()) - set(result.keys()) @@ -195,7 +195,7 @@ def parse_tfrecord_sequence_example(example, spec: Optional[ExampleSpec]) -> Exa @functional_datapipe("load_from_tfrecord") -class TFRecordLoaderIterDataPipe(IterDataPipe[Example]): +class TFRecordLoaderIterDataPipe(IterDataPipe[TFRecordExample]): r""" Opens/decompresses tfrecord binary streams from an Iterable DataPipe which contains tuples of path name and tfrecord binary stream, and yields the stored records (functional name: ``load_from_tfrecord``). @@ -219,7 +219,10 @@ class TFRecordLoaderIterDataPipe(IterDataPipe[Example]): """ def __init__( - self, datapipe: Iterable[Tuple[str, BufferedIOBase]], spec: Optional[ExampleSpec] = None, length: int = -1 + self, + datapipe: Iterable[Tuple[str, BufferedIOBase]], + spec: Optional[TFRecordExampleSpec] = None, + length: int = -1, ) -> None: super().__init__() _assert_protobuf() @@ -228,7 +231,7 @@ def __init__( self.length: int = length self.spec = spec - def __iter__(self) -> Iterator[Example]: + def __iter__(self) -> Iterator[TFRecordExample]: # We assume that the "example.proto" and "feature.proto" # stays the same for future TensorFlow versions. # If it changed, newer TensorFlow versions would