From e261ab6b0f681360a356052249f1664020cc6fae Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Wed, 1 Jun 2022 01:47:38 -0700
Subject: [PATCH] [fbsync] Remove `(N, T, H, W, C) => (N, T, C, H, W)` from presets (#6058)

Summary:
* Remove the `(N, T, H, W, C) => (N, T, C, H, W)` conversion from the presets
* Update docs
* Fix the tests
* Use `output_format` for `read_video()`
* Use `output_format` for `Kinetics()`
* Add input descriptions to the presets

Reviewed By: NicolasHug

Differential Revision: D36760943

fbshipit-source-id: 316f98583f39cc29b9a40f9c7c479b565981f088
---
 docs/source/models.rst                   |  2 +-
 gallery/plot_optical_flow.py             |  3 +--
 references/video_classification/train.py |  2 ++
 test/test_extended_models.py             |  2 +-
 torchvision/transforms/_presets.py       | 19 ++++++++++++++-----
 5 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/docs/source/models.rst b/docs/source/models.rst
index ea3c57bb62b..b549c25bf94 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -471,7 +471,7 @@ Here is an example of how to use the pre-trained video classification models:
     from torchvision.io.video import read_video
     from torchvision.models.video import r3d_18, R3D_18_Weights
 
-    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi")
+    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
     vid = vid[:32]  # optionally shorten duration
 
     # Step 1: Initialize model with the best available weights
diff --git a/gallery/plot_optical_flow.py b/gallery/plot_optical_flow.py
index 495422c1f3e..9e8d0006f1a 100644
--- a/gallery/plot_optical_flow.py
+++ b/gallery/plot_optical_flow.py
@@ -72,8 +72,7 @@ def plot(imgs, **imshow_kwargs):
 # single model input.
 
 from torchvision.io import read_video
-frames, _, _ = read_video(str(video_path))
-frames = frames.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+frames, _, _ = read_video(str(video_path), output_format="TCHW")
 
 img1_batch = torch.stack([frames[100], frames[150]])
 img2_batch = torch.stack([frames[101], frames[151]])
diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index c7ac9e8c133..e1df08cbe4a 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -157,6 +157,7 @@ def main(args):
                 "avi",
                 "mp4",
             ),
+            output_format="TCHW",
         )
         if args.cache_dataset:
             print(f"Saving dataset_train to {cache_path}")
@@ -193,6 +194,7 @@ def main(args):
                 "avi",
                 "mp4",
             ),
+            output_format="TCHW",
         )
         if args.cache_dataset:
             print(f"Saving dataset_test to {cache_path}")
diff --git a/test/test_extended_models.py b/test/test_extended_models.py
index 408a8c0514c..396e79c3f6d 100644
--- a/test/test_extended_models.py
+++ b/test/test_extended_models.py
@@ -180,7 +180,7 @@ def test_transforms_jit(model_fn):
         "input_shape": (1, 3, 520, 520),
     },
     "video": {
-        "input_shape": (1, 4, 112, 112, 3),
+        "input_shape": (1, 4, 3, 112, 112),
     },
     "optical_flow": {
         "input_shape": (1, 3, 128, 128),
diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py
index 765ae8ec3c4..e49912e0f00 100644
--- a/torchvision/transforms/_presets.py
+++ b/torchvision/transforms/_presets.py
@@ -29,7 +29,10 @@ def __repr__(self) -> str:
         return self.__class__.__name__ + "()"
 
     def describe(self) -> str:
-        return "The images are rescaled to ``[0.0, 1.0]``."
+        return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
+            "The images are rescaled to ``[0.0, 1.0]``."
+        )
 
 
 class ImageClassification(nn.Module):
@@ -70,6 +73,7 @@ def __repr__(self) -> str:
 
     def describe(self) -> str:
         return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
             f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
             f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
             f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
@@ -99,7 +103,6 @@ def forward(self, vid: Tensor) -> Tensor:
             vid = vid.unsqueeze(dim=0)
             need_squeeze = True
 
-        vid = vid.permute(0, 1, 4, 2, 3)  # (N, T, H, W, C) => (N, T, C, H, W)
         N, T, C, H, W = vid.shape
         vid = vid.view(-1, C, H, W)
         vid = F.resize(vid, self.resize_size, interpolation=self.interpolation)
@@ -126,9 +129,11 @@ def __repr__(self) -> str:
 
     def describe(self) -> str:
         return (
-            f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
+            "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
+            f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
             f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
-            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
+            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. The output "
+            "dimensions are then permuted to ``(..., C, T, H, W)``."
         )
 
 
@@ -167,6 +172,7 @@ def __repr__(self) -> str:
 
     def describe(self) -> str:
         return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
             f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
             f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
             f"``std={self.std}``."
@@ -196,4 +202,7 @@ def __repr__(self) -> str:
         return self.__class__.__name__ + "()"
 
     def describe(self) -> str:
-        return "The images are rescaled to ``[-1.0, 1.0]``."
+        return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
+            "The images are rescaled to ``[-1.0, 1.0]``."
+        )
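
Note: after this change the video presets no longer accept `(..., T, H, W, C)` input, so callers should decode frames directly to `TCHW` as the updated call sites above do. For reference, here is a minimal end-to-end sketch of the new flow, assembled from the updated `docs/source/models.rst` example (the video path is a placeholder):

    import torch

    from torchvision.io.video import read_video
    from torchvision.models.video import r3d_18, R3D_18_Weights

    # Decode straight to (T, C, H, W); the preset no longer permutes
    # from (T, H, W, C) internally.
    vid, _, _ = read_video("path/to/video.avi", output_format="TCHW")  # placeholder path
    vid = vid[:32]  # optionally shorten duration

    # Initialize the model with the best available weights
    weights = R3D_18_Weights.DEFAULT
    model = r3d_18(weights=weights)
    model.eval()

    # The inference preset now expects (T, C, H, W) or (B, T, C, H, W)
    # input and returns a tensor permuted to (..., C, T, H, W)
    preprocess = weights.transforms()
    batch = preprocess(vid).unsqueeze(0)

    # Run the model and print the predicted category
    with torch.no_grad():
        prediction = model(batch).squeeze(0).softmax(0)
    label = prediction.argmax().item()
    score = prediction[label].item()
    print(f"{weights.meta['categories'][label]}: {100 * score:.1f}%")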