From 6d2fba2988ac7f4cee36552dd4865ea6325df7a2 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Fri, 20 May 2022 17:12:56 +0100
Subject: [PATCH] Add output_format to video datasets and readers

---
 torchvision/datasets/hmdb51.py      |  6 +++++-
 torchvision/datasets/kinetics.py    | 14 +++++++++-----
 torchvision/datasets/ucf101.py      |  6 +++++-
 torchvision/datasets/video_utils.py | 10 ++++++++++
 torchvision/io/video.py             | 12 +++++++++++-
 5 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/torchvision/datasets/hmdb51.py b/torchvision/datasets/hmdb51.py
index 5bfb604c916..f7341f4aa30 100644
--- a/torchvision/datasets/hmdb51.py
+++ b/torchvision/datasets/hmdb51.py
@@ -37,11 +37,13 @@ class HMDB51(VisionDataset):
             otherwise from the ``test`` split.
         transform (callable, optional): A function/transform that takes in a TxHxWxC video
             and returns a transformed version.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" (default) or "TCHW".
 
     Returns:
         tuple: A 3-tuple with the following entries:
 
-            - video (Tensor[T, H, W, C]): The `T` video frames
+            - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
             - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
               and `L` is the number of points
             - label (int): class of the video clip
@@ -71,6 +73,7 @@ def __init__(
         _video_height: int = 0,
         _video_min_dimension: int = 0,
         _audio_samples: int = 0,
+        output_format: str = "THWC",
     ) -> None:
         super().__init__(root)
         if fold not in (1, 2, 3):
@@ -96,6 +99,7 @@ def __init__(
             _video_height=_video_height,
             _video_min_dimension=_video_min_dimension,
             _audio_samples=_audio_samples,
+            output_format=output_format,
         )
         # we bookkeep the full version of video clips because we want to be able
         # to return the meta data of full version rather than the subset version of
diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py
index 937cee495e0..b3c94c50de1 100644
--- a/torchvision/datasets/kinetics.py
+++ b/torchvision/datasets/kinetics.py
@@ -62,11 +62,14 @@ class Kinetics(VisionDataset):
         download (bool): Download the official version of the dataset to root folder.
         num_workers (int): Use multiple workers for VideoClips creation
         num_download_workers (int): Use multiprocessing in order to speed up download.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" or "TCHW" (default).
+            Note that in most other utils and datasets, the default is actually "THWC".
 
     Returns:
         tuple: A 3-tuple with the following entries:
 
-            - video (Tensor[T, C, H, W]): the `T` video frames in torch.uint8 tensor
+            - video (Tensor[T, C, H, W] or Tensor[T, H, W, C]): the `T` video frames in torch.uint8 tensor
             - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
               and `L` is the number of points in torch.float tensor
             - label (int): class of the video clip
@@ -106,6 +109,7 @@ def __init__(
         _audio_samples: int = 0,
         _audio_channels: int = 0,
         _legacy: bool = False,
+        output_format: str = "TCHW",
     ) -> None:
 
         # TODO: support test
@@ -115,10 +119,12 @@
 
         self.root = root
         self._legacy = _legacy
+
         if _legacy:
             print("Using legacy structure")
             self.split_folder = root
             self.split = "unknown"
+            output_format = "THWC"
             if download:
                 raise ValueError("Cannot download the videos using legacy_structure.")
         else:
@@ -145,6 +151,7 @@ def __init__(
             _video_min_dimension=_video_min_dimension,
             _audio_samples=_audio_samples,
             _audio_channels=_audio_channels,
+            output_format=output_format,
         )
 
         self.transform = transform
@@ -233,9 +240,6 @@ def __len__(self) -> int:
 
     def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
         video, audio, info, video_idx = self.video_clips.get_clip(idx)
-        if not self._legacy:
-            # [T,H,W,C] --> [T,C,H,W]
-            video = video.permute(0, 3, 1, 2)
         label = self.samples[video_idx][1]
 
         if self.transform is not None:
@@ -308,7 +312,7 @@ def __init__(
         warnings.warn(
             "The Kinetics400 class is deprecated since 0.12 and will be removed in 0.14."
             "Please use Kinetics(..., num_classes='400') instead."
-            "Note that Kinetics(..., num_classes='400') returns video in a more logical Tensor[T, C, H, W] format."
+            "Note that Kinetics(..., num_classes='400') returns video in a Tensor[T, C, H, W] format."
         )
         if any(value is not None for value in (num_classes, split, download, num_download_workers)):
             raise RuntimeError(
diff --git a/torchvision/datasets/ucf101.py b/torchvision/datasets/ucf101.py
index caf849860e4..4ee5f1f3df9 100644
--- a/torchvision/datasets/ucf101.py
+++ b/torchvision/datasets/ucf101.py
@@ -38,11 +38,13 @@ class UCF101(VisionDataset):
             otherwise from the ``test`` split.
         transform (callable, optional): A function/transform that takes in a TxHxWxC video
             and returns a transformed version.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" (default) or "TCHW".
 
     Returns:
         tuple: A 3-tuple with the following entries:
 
-            - video (Tensor[T, H, W, C]): the `T` video frames
+            - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
             - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
               and `L` is the number of points
             - label (int): class of the video clip
@@ -64,6 +66,7 @@ def __init__(
         _video_height: int = 0,
         _video_min_dimension: int = 0,
         _audio_samples: int = 0,
+        output_format: str = "THWC",
     ) -> None:
         super().__init__(root)
         if not 1 <= fold <= 3:
@@ -87,6 +90,7 @@ def __init__(
             _video_height=_video_height,
             _video_min_dimension=_video_min_dimension,
             _audio_samples=_audio_samples,
+            output_format=output_format,
         )
         # we bookkeep the full version of video clips because we want to be able
         # to return the meta data of full version rather than the subset version of
diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py
index d444496ffe7..3fdd50d19c7 100644
--- a/torchvision/datasets/video_utils.py
+++ b/torchvision/datasets/video_utils.py
@@ -99,6 +99,7 @@ class VideoClips:
             on the resampled video
         num_workers (int): how many subprocesses to use for data loading.
             0 means that the data will be loaded in the main process. (default: 0)
+        output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
     """
 
     def __init__(
@@ -115,6 +116,7 @@ def __init__(
         _video_max_dimension: int = 0,
         _audio_samples: int = 0,
         _audio_channels: int = 0,
+        output_format: str = "THWC",
     ) -> None:
 
         self.video_paths = video_paths
@@ -127,6 +129,9 @@ def __init__(
         self._video_max_dimension = _video_max_dimension
         self._audio_samples = _audio_samples
         self._audio_channels = _audio_channels
+        self.output_format = output_format.upper()
+        if self.output_format not in ("THWC", "TCHW"):
+            raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
 
         if _precomputed_metadata is None:
             self._compute_frame_pts()
@@ -366,6 +371,11 @@ def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]
             video = video[resampling_idx]
             info["video_fps"] = self.frame_rate
         assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
+
+        if self.output_format == "TCHW":
+            # [T,H,W,C] --> [T,C,H,W]
+            video = video.permute(0, 3, 1, 2)
+
         return video, audio, info, video_idx
 
     def __getstate__(self) -> Dict[str, Any]:
diff --git a/torchvision/io/video.py b/torchvision/io/video.py
index 1c758661164..ceb20fe52c0 100644
--- a/torchvision/io/video.py
+++ b/torchvision/io/video.py
@@ -239,6 +239,7 @@ def read_video(
     start_pts: Union[float, Fraction] = 0,
     end_pts: Optional[Union[float, Fraction]] = None,
     pts_unit: str = "pts",
+    output_format: str = "THWC",
 ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
     """
     Reads a video from a file, returning both the video frames as well as
@@ -252,15 +253,20 @@
             The end presentation time
         pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
             either 'pts' or 'sec'. Defaults to 'pts'.
+        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
 
     Returns:
-        vframes (Tensor[T, H, W, C]): the `T` video frames
+        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
         aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
         info (Dict): metadata for the video and audio.
            Can contain the fields video_fps (float) and audio_fps (int)
     """
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(read_video)
+    output_format = output_format.upper()
+    if output_format not in ("THWC", "TCHW"):
+        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
+
     from torchvision import get_video_backend
 
     if not os.path.exists(filename):
@@ -334,6 +340,10 @@ def read_video(
     else:
         aframes = torch.empty((1, 0), dtype=torch.float32)
 
+    if output_format == "TCHW":
+        # [T,H,W,C] --> [T,C,H,W]
+        vframes = vframes.permute(0, 3, 1, 2)
+
     return vframes, aframes, info
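
A minimal usage sketch of the new argument at the I/O level, assuming a local clip at the hypothetical path "./clip.mp4". The datasets forward output_format to VideoClips, which applies the same permutation in get_clip, so read_video illustrates the behavior for all of them:

    from torchvision.io import read_video

    # Default "THWC": frames are returned as a [T, H, W, C] uint8 tensor.
    vframes, aframes, info = read_video("./clip.mp4", pts_unit="sec")

    # "TCHW" permutes the frames to [T, C, H, W] before returning,
    # which is what Kinetics previously did internally in non-legacy mode.
    vframes_tchw, _, _ = read_video("./clip.mp4", pts_unit="sec", output_format="TCHW")
    assert vframes_tchw.shape == vframes.permute(0, 3, 1, 2).shape

    # The value is upper-cased before validation, so "tchw" is accepted;
    # any other string raises a ValueError.
    vframes_lower, _, _ = read_video("./clip.mp4", pts_unit="sec", output_format="tchw")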
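
On the dataset side, note the differing defaults: Kinetics now defaults to "TCHW", while HMDB51, UCF101 and VideoClips keep "THWC". A sketch of keeping the old layout for THWC-based transforms, assuming a prepared dataset under the hypothetical root "./kinetics400":

    from torchvision.datasets import Kinetics

    # Kinetics defaults to output_format="TCHW"; request "THWC" explicitly
    # to keep the layout that THWC-based transforms expect.
    ds = Kinetics(
        "./kinetics400",  # hypothetical dataset root
        frames_per_clip=16,
        num_classes="400",
        split="val",
        output_format="THWC",
    )
    video, audio, label = ds[0]  # video is [T, H, W, C] with "THWC"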