Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add output_format to video datasets and readers #6061

Merged
merged 2 commits into from
May 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion torchvision/datasets/hmdb51.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@ class HMDB51(VisionDataset):
otherwise from the ``test`` split.
transform (callable, optional): A function/transform that takes in a TxHxWxC video
and returns a transformed version.
output_format (str, optional): The format of the output video tensors (before transforms).
Can be either "THWC" (default) or "TCHW".

Returns:
tuple: A 3-tuple with the following entries:

- video (Tensor[T, H, W, C]): The `T` video frames
- video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
and `L` is the number of points
- label (int): class of the video clip
Expand Down Expand Up @@ -71,6 +73,7 @@ def __init__(
_video_height: int = 0,
_video_min_dimension: int = 0,
_audio_samples: int = 0,
output_format: str = "THWC",
) -> None:
super().__init__(root)
if fold not in (1, 2, 3):
Expand All @@ -96,6 +99,7 @@ def __init__(
_video_height=_video_height,
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
output_format=output_format,
)
# we bookkeep the full version of video clips because we want to be able
# to return the meta data of full version rather than the subset version of
Expand Down
14 changes: 9 additions & 5 deletions torchvision/datasets/kinetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,14 @@ class Kinetics(VisionDataset):
download (bool): Download the official version of the dataset to root folder.
num_workers (int): Use multiple workers for VideoClips creation
num_download_workers (int): Use multiprocessing in order to speed up download.
output_format (str, optional): The format of the output video tensors (before transforms).
Can be either "THWC" or "TCHW" (default).
Note that in most other utils and datasets, the default is actually "THWC".

Returns:
tuple: A 3-tuple with the following entries:

- video (Tensor[T, C, H, W]): the `T` video frames in torch.uint8 tensor
- video (Tensor[T, C, H, W] or Tensor[T, H, W, C]): the `T` video frames in torch.uint8 tensor
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
and `L` is the number of points in torch.float tensor
- label (int): class of the video clip
Expand Down Expand Up @@ -106,6 +109,7 @@ def __init__(
_audio_samples: int = 0,
_audio_channels: int = 0,
_legacy: bool = False,
output_format: str = "TCHW",
) -> None:

# TODO: support test
Expand All @@ -115,10 +119,12 @@ def __init__(

self.root = root
self._legacy = _legacy

if _legacy:
print("Using legacy structure")
self.split_folder = root
self.split = "unknown"
output_format = "THWC"
if download:
raise ValueError("Cannot download the videos using legacy_structure.")
else:
Expand All @@ -145,6 +151,7 @@ def __init__(
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
_audio_channels=_audio_channels,
output_format=output_format,
)
self.transform = transform

Expand Down Expand Up @@ -233,9 +240,6 @@ def __len__(self) -> int:

def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
video, audio, info, video_idx = self.video_clips.get_clip(idx)
if not self._legacy:
# [T,H,W,C] --> [T,C,H,W]
video = video.permute(0, 3, 1, 2)
label = self.samples[video_idx][1]

if self.transform is not None:
Expand Down Expand Up @@ -308,7 +312,7 @@ def __init__(
warnings.warn(
"The Kinetics400 class is deprecated since 0.12 and will be removed in 0.14."
"Please use Kinetics(..., num_classes='400') instead."
"Note that Kinetics(..., num_classes='400') returns video in a more logical Tensor[T, C, H, W] format."
"Note that Kinetics(..., num_classes='400') returns video in a Tensor[T, C, H, W] format."
)
if any(value is not None for value in (num_classes, split, download, num_download_workers)):
raise RuntimeError(
Expand Down
6 changes: 5 additions & 1 deletion torchvision/datasets/ucf101.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ class UCF101(VisionDataset):
otherwise from the ``test`` split.
transform (callable, optional): A function/transform that takes in a TxHxWxC video
and returns a transformed version.
output_format (str, optional): The format of the output video tensors (before transforms).
Can be either "THWC" (default) or "TCHW".

Returns:
tuple: A 3-tuple with the following entries:

- video (Tensor[T, H, W, C]): the `T` video frames
- video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
and `L` is the number of points
- label (int): class of the video clip
Expand All @@ -64,6 +66,7 @@ def __init__(
_video_height: int = 0,
_video_min_dimension: int = 0,
_audio_samples: int = 0,
output_format: str = "THWC",
) -> None:
super().__init__(root)
if not 1 <= fold <= 3:
Expand All @@ -87,6 +90,7 @@ def __init__(
_video_height=_video_height,
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
output_format=output_format,
)
# we bookkeep the full version of video clips because we want to be able
# to return the meta data of full version rather than the subset version of
Expand Down
10 changes: 10 additions & 0 deletions torchvision/datasets/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class VideoClips:
on the resampled video
num_workers (int): how many subprocesses to use for data loading.
0 means that the data will be loaded in the main process. (default: 0)
output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
"""

def __init__(
Expand All @@ -115,6 +116,7 @@ def __init__(
_video_max_dimension: int = 0,
_audio_samples: int = 0,
_audio_channels: int = 0,
output_format: str = "THWC",
) -> None:

self.video_paths = video_paths
Expand All @@ -127,6 +129,9 @@ def __init__(
self._video_max_dimension = _video_max_dimension
self._audio_samples = _audio_samples
self._audio_channels = _audio_channels
self.output_format = output_format.upper()
if self.output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

if _precomputed_metadata is None:
self._compute_frame_pts()
Expand Down Expand Up @@ -366,6 +371,11 @@ def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]
video = video[resampling_idx]
info["video_fps"] = self.frame_rate
assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"

if self.output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
video = video.permute(0, 3, 1, 2)

return video, audio, info, video_idx

def __getstate__(self) -> Dict[str, Any]:
Expand Down
12 changes: 11 additions & 1 deletion torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def read_video(
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts",
output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Reads a video from a file, returning both the video frames as well as
Expand All @@ -252,15 +253,20 @@ def read_video(
The end presentation time
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
either 'pts' or 'sec'. Defaults to 'pts'.
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

Returns:
vframes (Tensor[T, H, W, C]): the `T` video frames
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_video)

output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

from torchvision import get_video_backend

if not os.path.exists(filename):
Expand Down Expand Up @@ -334,6 +340,10 @@ def read_video(
else:
aframes = torch.empty((1, 0), dtype=torch.float32)

if output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
vframes = vframes.permute(0, 3, 1, 2)

return vframes, aframes, info


Expand Down