Remove (N, T, H, W, C) => (N, T, C, H, W) from presets (#6058)
* Remove `(N, T, H, W, C) => (N, T, C, H, W)` conversion on presets

* Update docs.

* Fix the tests

* Use `output_format` for `read_video()`

* Use `output_format` for `Kinetics()`

* Add input descriptions to presets
datumbox authored May 23, 2022
1 parent 4c66813 commit 60ce5bf
Showing 5 changed files with 19 additions and 9 deletions.
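
At a glance: the video presets no longer permute channels-last frames internally, so callers now produce ``(..., T, C, H, W)`` input themselves, typically by asking the decoder for it. A minimal sketch of the new calling convention (the video path is a placeholder):

```python
from torchvision.io import read_video
from torchvision.models.video import R3D_18_Weights

weights = R3D_18_Weights.DEFAULT
preprocess = weights.transforms()  # the VideoClassification preset

# The preset no longer converts (..., T, H, W, C) to (..., T, C, H, W),
# so request channels-first frames directly from the decoder.
vid, _, _ = read_video("clip.avi", output_format="TCHW")  # placeholder path
batch = preprocess(vid[:32].unsqueeze(0))  # in: (B, T, C, H, W); out: (B, C, T, H, W)
```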
2 changes: 1 addition & 1 deletion docs/source/models.rst
@@ -471,7 +471,7 @@ Here is an example of how to use the pre-trained video classification models:
from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights
- vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi")
+ vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
vid = vid[:32] # optionally shorten duration
# Step 1: Initialize model with the best available weights
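
The docs example continues past this hunk; for context, here is a sketch of the full updated snippet, reconstructed from the visible lines and the standard weights-API pattern (the later steps are not shown in this diff):

```python
from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights

vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
vid = vid[:32]  # optionally shorten duration

# Step 1: Initialize model with the best available weights
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)  # (T, C, H, W) -> (1, C, T, H, W)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
print(weights.meta["categories"][label])
```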
3 changes: 1 addition & 2 deletions gallery/plot_optical_flow.py
@@ -72,8 +72,7 @@ def plot(imgs, **imshow_kwargs):
# single model input.

from torchvision.io import read_video
- frames, _, _ = read_video(str(video_path))
- frames = frames.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+ frames, _, _ = read_video(str(video_path), output_format="TCHW")

img1_batch = torch.stack([frames[100], frames[150]])
img2_batch = torch.stack([frames[101], frames[151]])
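
The change is behavior-preserving: ``output_format="TCHW"`` simply moves the permute into ``read_video``. A quick equivalence check (sketch; the path is a placeholder):

```python
import torch
from torchvision.io import read_video

frames_last, _, _ = read_video("clip.mp4")                         # (T, H, W, C)
frames_first, _, _ = read_video("clip.mp4", output_format="TCHW")  # (T, C, H, W)

# Same pixels, different layout: the manual permute the gallery
# example used to do is now done inside the decoder.
assert torch.equal(frames_last.permute(0, 3, 1, 2), frames_first)
```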
2 changes: 2 additions & 0 deletions references/video_classification/train.py
@@ -157,6 +157,7 @@ def main(args):
"avi",
"mp4",
),
+ output_format="TCHW",
)
if args.cache_dataset:
print(f"Saving dataset_train to {cache_path}")
@@ -193,6 +194,7 @@
"avi",
"mp4",
),
+ output_format="TCHW",
)
if args.cache_dataset:
print(f"Saving dataset_test to {cache_path}")
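
``Kinetics()`` exposes the same knob, which is what the two hunks above pass through. A standalone sketch (root path and clip settings are illustrative, not taken from this repo):

```python
from torchvision.datasets import Kinetics

dataset = Kinetics(
    root="datasets/kinetics400",  # placeholder path
    frames_per_clip=16,
    num_classes="400",
    split="train",
    extensions=("avi", "mp4"),
    output_format="TCHW",         # clips come back as (T, C, H, W)
)
video, audio, label = dataset[0]  # video: (T, C, H, W) uint8 tensor
```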
2 changes: 1 addition & 1 deletion test/test_extended_models.py
@@ -180,7 +180,7 @@ def test_transforms_jit(model_fn):
"input_shape": (1, 3, 520, 520),
},
"video": {
"input_shape": (1, 4, 112, 112, 3),
"input_shape": (1, 4, 3, 112, 112),
},
"optical_flow": {
"input_shape": (1, 3, 128, 128),
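
Reading the new smoke-test shape out loud: a batch of 1 clip with 4 frames, 3 channels, and 112x112 pixels, i.e. channels-first per frame as the preset now expects:

```python
import torch

# (B, T, C, H, W) replaces the old channels-last (B, T, H, W, C).
x = torch.rand(1, 4, 3, 112, 112)
```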
19 changes: 14 additions & 5 deletions torchvision/transforms/_presets.py
@@ -29,7 +29,10 @@ def __repr__(self) -> str:
return self.__class__.__name__ + "()"

def describe(self) -> str:
return "The images are rescaled to ``[0.0, 1.0]``."
return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
"The images are rescaled to ``[0.0, 1.0]``."
)


class ImageClassification(nn.Module):
@@ -70,6 +73,7 @@ def __repr__(self) -> str:

def describe(self) -> str:
return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
@@ -99,7 +103,6 @@ def forward(self, vid: Tensor) -> Tensor:
vid = vid.unsqueeze(dim=0)
need_squeeze = True

- vid = vid.permute(0, 1, 4, 2, 3)  # (N, T, H, W, C) => (N, T, C, H, W)
N, T, C, H, W = vid.shape
vid = vid.view(-1, C, H, W)
vid = F.resize(vid, self.resize_size, interpolation=self.interpolation)
@@ -126,9 +129,11 @@ def __repr__(self) -> str:

def describe(self) -> str:
return (
f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
"Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output "
"dimensions are permuted to ``(..., C, T, H, W)`` tensors."
)


@@ -167,6 +172,7 @@ def __repr__(self) -> str:

def describe(self) -> str:
return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
f"``std={self.std}``."
@@ -196,4 +202,7 @@ def __repr__(self) -> str:
return self.__class__.__name__ + "()"

def describe(self) -> str:
return "The images are rescaled to ``[-1.0, 1.0]``."
return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
"The images are rescaled to ``[-1.0, 1.0]``."
)
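
With the input descriptions in place, the accepted layout is discoverable at runtime, e.g. (sketch):

```python
from torchvision.models.video import R3D_18_Weights

preset = R3D_18_Weights.DEFAULT.transforms()
print(preset.describe())
# Mentions batched (B, T, C, H, W) / single (T, C, H, W) inputs and
# the (..., C, T, H, W) output layout.
```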
