Skip to content

Commit

Permalink
[fbsync] [proto] Another attempt to rewrite RandomCrop (#6410)
Browse files Browse the repository at this point in the history
Summary:
* [proto] Another attempt to rewrite RandomCrop

* Fixed implementation issue and updated tests

Reviewed By: datumbox

Differential Revision: D38824238

fbshipit-source-id: ea8d8ffadba1e6e85920b6f937470fe671e3babc
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Aug 24, 2022
1 parent 5b1e1cc commit e8e42e1
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 30 deletions.
41 changes: 32 additions & 9 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,25 +623,49 @@ def test_assertions(self):
with pytest.raises(ValueError, match="Padding mode should be either"):
transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")

@pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
@pytest.mark.parametrize("size, pad_if_needed", [((10, 10), False), ((50, 25), True)])
def test__get_params(self, padding, pad_if_needed, size, mocker):
    """Check that RandomCrop._get_params returns valid crop coordinates and the
    (possibly statically padded) input size, for every padding spec and for the
    pad_if_needed branch (crop larger than the 24x32 mocked image).
    """
    image = mocker.MagicMock(spec=features.Image)
    image.num_channels = 3
    image.image_size = (24, 32)
    h, w = image.image_size

    transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed)
    params = transform._get_params(image)

    # Mirror the static-padding size arithmetic of _get_params:
    # int -> all four sides; [lr, tb] -> symmetric; [l, t, r, b] -> per side.
    if padding is not None:
        if isinstance(padding, int):
            h += 2 * padding
            w += 2 * padding
        elif isinstance(padding, list) and len(padding) == 2:
            w += 2 * padding[0]
            h += 2 * padding[1]
        elif isinstance(padding, list) and len(padding) == 4:
            w += padding[0] + padding[2]
            h += padding[1] + padding[3]

    # _get_params reports the size after static padding but BEFORE pad_if_needed,
    # because _transform re-applies the pad_if_needed padding itself.
    expected_input_width = w
    expected_input_height = h

    if pad_if_needed:
        # pad_if_needed grows each deficient dimension by twice the deficit
        # (the implementation pads symmetrically).
        if w < size[1]:
            w += 2 * (size[1] - w)
        if h < size[0]:
            h += 2 * (size[0] - h)

    # torch.randint(0, high) excludes high, so the maximum offset is exactly
    # h - size[0] / w - size[1] (the previous "+ 1" slack was too loose).
    assert 0 <= params["top"] <= h - size[0]
    assert 0 <= params["left"] <= w - size[1]
    assert params["height"] == size[0]
    assert params["width"] == size[1]
    assert params["input_width"] == expected_input_width
    assert params["input_height"] == expected_input_height

@pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
@pytest.mark.parametrize("pad_if_needed", [False, True])
@pytest.mark.parametrize("fill", [False, True])
@pytest.mark.parametrize("padding_mode", ["constant", "edge"])
def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker):
output_size = [10, 12]
transform = transforms.RandomCrop(
output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
Expand Down Expand Up @@ -671,13 +695,12 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
params = transform._get_params(inpt)
if padding is None and not pad_if_needed:
params = transform._get_params(inpt)
fn_crop.assert_called_once_with(
inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
)
elif not pad_if_needed:
params = transform._get_params(expected)
fn_crop.assert_called_once_with(
expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
)
Expand Down
62 changes: 41 additions & 21 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
from torchvision.transforms.functional_tensor import _parse_pad_padding
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size
from typing_extensions import Literal

Expand Down Expand Up @@ -448,43 +449,62 @@ def __init__(
def _get_params(self, sample: Any) -> Dict[str, Any]:
    """Sample crop coordinates for RandomCrop.

    Returns a dict with the crop window (``top``, ``left``, ``height``,
    ``width``) plus ``input_width``/``input_height``: the image size after
    static padding but before pad_if_needed, which ``_transform`` needs to
    decide how much extra padding to apply.

    Raises:
        ValueError: if the (padded) image is still more than one pixel
            smaller than the requested crop in either dimension.
    """
    image = query_image(sample)
    _, height, width = get_image_dimensions(image)

    if self.padding is not None:
        # Account for the static padding that _transform will apply,
        # without actually padding here.
        padding = self.padding
        if isinstance(padding, Sequence):
            padding = list(padding)
        pad_left, pad_right, pad_top, pad_bottom = _parse_pad_padding(padding)
        height += pad_top + pad_bottom
        width += pad_left + pad_right

    output_height, output_width = self.size
    # Remember the maybe-statically-padded size for the pad_if_needed branch
    # in _transform.
    input_height, input_width = height, width

    if self.pad_if_needed:
        # pad_if_needed pads symmetrically, so each deficient dimension grows
        # by twice the deficit.
        if width < output_width:
            width += 2 * (output_width - width)
        if height < output_height:
            height += 2 * (output_height - height)

    if height + 1 < output_height or width + 1 < output_width:
        raise ValueError(
            f"Required crop size {(output_height, output_width)} is larger than input image size {(height, width)}"
        )

    if width == output_width and height == output_height:
        # Exact fit: no randomness needed.
        return dict(top=0, left=0, height=height, width=width, input_width=input_width, input_height=input_height)

    # randint's upper bound is exclusive, so top/left range over
    # [0, height - output_height] / [0, width - output_width].
    top = torch.randint(0, height - output_height + 1, size=(1,)).item()
    left = torch.randint(0, width - output_width + 1, size=(1,)).item()
    return dict(
        top=top,
        left=left,
        height=output_height,
        width=output_width,
        input_width=input_width,
        input_height=input_height,
    )

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    """Apply static padding, then pad_if_needed padding, then the crop.

    The padding applied here must reproduce exactly the size arithmetic that
    ``_get_params`` assumed when it sampled ``top``/``left``; the
    ``input_width``/``input_height`` entries in ``params`` carry the
    post-static-padding size across the two methods.
    """
    if self.padding is not None:
        inpt = F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)

    if self.pad_if_needed:
        input_width, input_height = params["input_width"], params["input_height"]
        # F.pad with a 2-element list pads [left/right, top/bottom]
        # symmetrically, matching the 2x growth assumed in _get_params.
        if input_width < self.size[1]:
            padding = [self.size[1] - input_width, 0]
            inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode)
        if input_height < self.size[0]:
            padding = [0, self.size[0] - input_height]
            inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode)

    return F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])


class RandomPerspective(_RandomApplyTransform):
Expand Down

0 comments on commit e8e42e1

Please sign in to comment.