Skip to content

Commit

Permalink
Fix RAFT input dimension check
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jan 13, 2025
1 parent 06a925c commit 1e08009
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchvision/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
batch_size, _, h, w = image1.shape
if (h, w) != image2.shape[-2:]:
raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
if not (h % 8 == 0) and (w % 8 == 0):
if not ((h % 8 == 0) and (w % 8 == 0)):
raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")

fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
Expand Down

0 comments on commit 1e08009

Please sign in to comment.