Skip to content

Commit

Permalink
Random rotation (apache#16794)
Browse files Browse the repository at this point in the history
* Added function for image rotation (imrotate) using BilinearSampler

Added unit tests for imrotate

Added Rotate to transforms

Added RandomRotation to transforms
Added transforms tests

* Made rotations generic - any angle

* Added tests. Removed useless docstrings. Updated API.

Co-authored-by: Douglas Andrade  <[email protected]>

Co-authored-by: Douglas Coimbra de Andrade <[email protected]>
  • Loading branch information
2 people authored and anirudh2290 committed May 29, 2020
1 parent 7224cc2 commit eab87f5
Show file tree
Hide file tree
Showing 4 changed files with 354 additions and 0 deletions.
73 changes: 73 additions & 0 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"Image transforms."

import random
import numpy as np

from ...block import Block, HybridBlock
from ...nn import Sequential, HybridSequential
from .... import image
Expand Down Expand Up @@ -198,6 +200,77 @@ def hybrid_forward(self, F, x):
return F.image.normalize(x, self._mean, self._std)


class Rotate(Block):
    """Rotate the input image by a given angle. Keeps the original image shape.

    Parameters
    ----------
    rotation_degrees : float32
        Desired rotation angle in degrees.
    zoom_in : bool
        Zoom in image so that no padding is present in final output.
    zoom_out : bool
        Zoom out image so that the entire original image is present in final output.


    Inputs:
        - **data**: input tensor with (C x H x W) or (N x C x H x W) shape.

    Outputs:
        - **out**: output tensor with (C x H x W) or (N x C x H x W) shape.

    Raises
    ------
    TypeError
        If the input passed to `forward` is not of dtype float32.
    """
    def __init__(self, rotation_degrees, zoom_in=False, zoom_out=False):
        super(Rotate, self).__init__()
        # Stored in the positional order expected by `image.imrotate`.
        self._args = (rotation_degrees, zoom_in, zoom_out)

    def forward(self, x):
        # Compare dtypes by equality rather than identity: an `np.dtype`
        # instance describing float32 is equal to `np.float32` but is not the
        # same object, so an `is not` check would wrongly reject it.
        if x.dtype != np.float32:
            raise TypeError("This transformation only supports float32. "
                            "Consider calling it after ToTensor")
        return image.imrotate(x, *self._args)


class RandomRotation(Block):
    """Rotate the input image by a random angle.
    Keeps the original image shape and aspect ratio.

    Parameters
    ----------
    angle_limits: tuple
        Tuple of 2 elements ``(lower, upper)`` containing the lower and upper
        limit for rotation angles in degree. `lower` must be strictly smaller
        than `upper`.
    zoom_in : bool
        Zoom in image so that no padding is present in final output.
    zoom_out : bool
        Zoom out image so that the entire original image is present in final output.
    rotate_with_proba : float32
        Probability, in [0, 1], of rotating the image on a given call.
        When the image is not rotated, the input is returned unchanged.


    Inputs:
        - **data**: input tensor with (C x H x W) or (N x C x H x W) shape.

    Outputs:
        - **out**: output tensor with (C x H x W) or (N x C x H x W) shape.

    Raises
    ------
    ValueError
        At construction, if `angle_limits` is not strictly ordered or
        `rotate_with_proba` lies outside [0, 1].
    TypeError
        In `forward`, if a rotation is attempted on a non-float32 input.
    """
    def __init__(self, angle_limits, zoom_in=False, zoom_out=False, rotate_with_proba=1.0):
        super(RandomRotation, self).__init__()
        lower, upper = angle_limits
        if lower >= upper:
            raise ValueError("`angle_limits` must be an ordered tuple")
        if rotate_with_proba < 0 or rotate_with_proba > 1:
            raise ValueError("Probability of rotating the image should be between 0 and 1")
        # Stored in the positional order expected by `image.random_rotate`.
        self._args = (angle_limits, zoom_in, zoom_out)
        self._rotate_with_proba = rotate_with_proba

    def forward(self, x):
        # Skip the rotation with probability (1 - rotate_with_proba).
        # The dtype is only validated when a rotation is actually attempted.
        if np.random.random() > self._rotate_with_proba:
            return x
        # Compare dtypes by equality rather than identity: an `np.dtype`
        # instance describing float32 is equal to `np.float32` but is not the
        # same object, so an `is not` check would wrongly reject it.
        if x.dtype != np.float32:
            raise TypeError("This transformation only supports float32. "
                            "Consider calling it after ToTensor")
        return image.random_rotate(x, *self._args)


class RandomResizedCrop(Block):
"""Crop the input image with random scale and aspect ratio.
Expand Down
143 changes: 143 additions & 0 deletions python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
import logging
import json
import warnings

from numbers import Number

import numpy as np

from .. import numpy as _mx_np # pylint: disable=reimported


Expand Down Expand Up @@ -612,6 +616,145 @@ def random_size_crop(src, size, area, ratio, interp=2, **kwargs):
return center_crop(src, size, interp)


def imrotate(src, rotation_degrees, zoom_in=False, zoom_out=False):
    """Rotate the input image(s) by the given angle(s), keeping the input shape.

    The rotation is implemented by building a sampling grid centered on the
    image, rotating that grid and resampling `src` with `nd.BilinearSampler`.

    Parameters
    ----------
    src : NDArray
        Input image (format CHW) or batch of images (format NCHW),
        in both cases a float32 data type is required.
    rotation_degrees: scalar or NDArray
        Wanted rotation in degrees. When `src` is a single image a scalar
        is required; for a batch, either a scalar (applied to every image)
        or a mono-dimensional vector holding one angle per image.
    zoom_in: bool
        If True input image(s) will be zoomed in a way so that no padding
        will be shown in the output result.
    zoom_out: bool
        If True input image(s) will be zoomed in a way so that the whole
        original image will be contained in the output result.

    Returns
    -------
    NDArray
        An `NDArray` containing the rotated image(s), with the same number
        of dimensions as `src`.

    Raises
    ------
    ValueError
        If both `zoom_in` and `zoom_out` are True, if `src` is neither 3D
        nor 4D, or if the number of angles differs from the number of images.
    TypeError
        If `src` is not float32, or if a non-scalar angle is passed together
        with a single (3D) image.
    """
    if zoom_in and zoom_out:
        raise ValueError("`zoom_in` and `zoom_out` cannot be both True")
    # Compare dtypes by equality rather than identity: an `np.dtype` instance
    # describing float32 is equal to `np.float32` but is not the same object,
    # so an `is not` check would wrongly reject it.
    if src.dtype != np.float32:
        raise TypeError("Only `float32` images are supported by this function")
    # handles the case in which a single image is passed to this function:
    # it is promoted to a batch of one and unwrapped again before returning
    expanded = False
    if src.ndim == 3:
        expanded = True
        src = src.expand_dims(axis=0)
        if not isinstance(rotation_degrees, Number):
            raise TypeError("When a single image is passed the rotation angle is "
                            "required to be a scalar.")
    elif src.ndim != 4:
        raise ValueError("Only 3D and 4D are supported by this function")

    # when a scalar is passed we wrap it into an array (one angle per image)
    if isinstance(rotation_degrees, Number):
        rotation_degrees = nd.array([rotation_degrees] * len(src),
                                    ctx=src.context)

    if len(src) != len(rotation_degrees):
        raise ValueError(
            "The number of images must be equal to the number of rotation angles"
        )

    rotation_degrees = rotation_degrees.as_in_context(src.context)
    rotation_rad = np.pi * rotation_degrees / 180
    # reshape the rotations angle in order to be broadcasted
    # over the `src` tensor: (N,) -> (N, 1, 1)
    rotation_rad = rotation_rad.expand_dims(axis=1).expand_dims(axis=2)
    _, _, h, w = src.shape

    # Generate a grid centered at the center of the image;
    # both coordinate matrices have shape (1, h, w)
    hscale = (float(h - 1) / 2)
    wscale = (float(w - 1) / 2)
    row_index = nd.arange(h, ctx=src.context).astype('float32').reshape(h, 1)
    col_index = nd.arange(w, ctx=src.context).astype('float32').reshape(1, w)
    h_matrix = (nd.repeat(row_index, w, axis=1) - hscale).expand_dims(axis=0)
    w_matrix = (nd.repeat(col_index, h, axis=0) - wscale).expand_dims(axis=0)
    # perform rotation on the grid; broadcasting against the (N, 1, 1)
    # angles yields one (h, w) coordinate map per image
    c_alpha = nd.cos(rotation_rad)
    s_alpha = nd.sin(rotation_rad)
    w_matrix_rot = w_matrix * c_alpha - h_matrix * s_alpha
    h_matrix_rot = w_matrix * s_alpha + h_matrix * c_alpha
    # NOTE: grid normalization must be performed after the rotation
    #       to keep the aspect ratio
    w_matrix_rot = w_matrix_rot / wscale
    h_matrix_rot = h_matrix_rot / hscale

    # from here on `h` and `w` are one-element NDArrays so that they can take
    # part in the NDArray arithmetic below
    h, w = nd.array([h], ctx=src.context), nd.array([w], ctx=src.context)
    # compute the scale factor in case `zoom_in` or `zoom_out` are True
    if zoom_in or zoom_out:
        # polar coordinates (length, angle) of the image diagonal
        rho_corner = nd.sqrt(h * h + w * w)
        ang_corner = nd.arctan(h / w)
        # extents of the two diagonals after rotating by +/- the angle
        corner1_x_pos = nd.abs(rho_corner * nd.cos(ang_corner + nd.abs(rotation_rad)))
        corner1_y_pos = nd.abs(rho_corner * nd.sin(ang_corner + nd.abs(rotation_rad)))
        corner2_x_pos = nd.abs(rho_corner * nd.cos(ang_corner - nd.abs(rotation_rad)))
        corner2_y_pos = nd.abs(rho_corner * nd.sin(ang_corner - nd.abs(rotation_rad)))
        max_x = nd.maximum(corner1_x_pos, corner2_x_pos)
        max_y = nd.maximum(corner1_y_pos, corner2_y_pos)
        if zoom_out:
            scale_x = max_x / w
            scale_y = max_y / h
            globalscale = nd.maximum(scale_x, scale_y)
        else:
            scale_x = w / max_x
            scale_y = h / max_y
            globalscale = nd.minimum(scale_x, scale_y)
        # (N, 1, 1) -> (N, 1, 1, 1) so it broadcasts over the (N, 2, h, w) grid
        globalscale = globalscale.expand_dims(axis=3)
    else:
        globalscale = 1
    # stack the x (width) and y (height) coordinates along a new axis 1
    grid = nd.concat(w_matrix_rot.expand_dims(axis=1),
                     h_matrix_rot.expand_dims(axis=1), dim=1)
    grid = grid * globalscale
    rot_img = nd.BilinearSampler(src, grid)
    if expanded:
        return rot_img[0]
    return rot_img


def random_rotate(src, angle_limits, zoom_in=False, zoom_out=False):
    """Rotate `src` by an angle drawn uniformly from `angle_limits`.

    A single image (3D input) gets one scalar angle; for any other input
    one independent angle is drawn per element along the first axis.

    Parameters
    ----------
    src : NDArray
        Input image (format CHW) or batch of images (format NCHW),
        in both cases a float32 data type is required.
    angle_limits: tuple
        Tuple of 2 elements containing the lower and upper limit
        for rotation angles in degree.
    zoom_in: bool
        If True input image(s) will be zoomed in a way so that no padding
        will be shown in the output result.
    zoom_out: bool
        If True input image(s) will be zoomed in a way so that the whole
        original image will be contained in the output result.

    Returns
    -------
    NDArray
        An `NDArray` containing the rotated image(s).
    """
    if src.ndim != 3:
        # batch input: draw one angle per image
        batch_size = src.shape[0]
        sampled_angles = np.random.uniform(*angle_limits, size=batch_size)
        rotation_degrees = nd.array(sampled_angles)
    else:
        # single image: a plain scalar angle is expected by `imrotate`
        rotation_degrees = np.random.uniform(*angle_limits)
    return imrotate(src, rotation_degrees, zoom_in=zoom_in, zoom_out=zoom_out)


class Augmenter(object):
"""Image Augmenter base class"""
def __init__(self, **kwargs):
Expand Down
56 changes: 56 additions & 0 deletions tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,67 @@ def test_transformer():
transforms.RandomHue(0.1),
transforms.RandomLighting(0.1),
transforms.ToTensor(),
transforms.RandomRotation([-10., 10.]),
transforms.Normalize([0, 0, 0], [1, 1, 1])])

transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()


@with_seed()
def test_rotate():
    # non-float32 inputs must be rejected
    transformer = transforms.Rotate(10.)
    assertRaises(TypeError, transformer, mx.nd.ones((3, 30, 60), dtype='uint8'))
    # the output shape must match the input shape, for single images and batches
    single_image = mx.nd.ones((3, 30, 60), dtype='float32')
    single_output = transformer(single_image)
    assert same(single_output.shape, (3, 30, 60))
    batch_image = mx.nd.ones((3, 3, 30, 60), dtype='float32')
    batch_output = transformer(batch_image)
    assert same(batch_output.shape, (3, 3, 30, 60))

    # rotating a single bright pixel by multiples of 90 degrees must move it
    # to the expected border cell of a 3x3 image
    input_image = nd.array([[[0., 0., 0.],
                             [0., 0., 1.],
                             [0., 0., 0.]]])
    rotation_angles_expected_outs = [
        (90., nd.array([[[0., 1., 0.],
                         [0., 0., 0.],
                         [0., 0., 0.]]])),
        (180., nd.array([[[0., 0., 0.],
                          [1., 0., 0.],
                          [0., 0., 0.]]])),
        (270., nd.array([[[0., 0., 0.],
                          [0., 0., 0.],
                          [0., 1., 0.]]])),
        (360., nd.array([[[0., 0., 0.],
                          [0., 0., 1.],
                          [0., 0., 0.]]])),
    ]
    for rot_angle, expected_result in rotation_angles_expected_outs:
        transformer = transforms.Rotate(rot_angle)
        ans = transformer(input_image)
        # no debug printing here: the assertion is the check under test
        assert_almost_equal(ans, expected_result, atol=1e-6)


@with_seed()
def test_random_rotation():
    # a probability outside of [0, 1] must be rejected at construction time
    for invalid_proba in (1.1, -0.3):
        assertRaises(ValueError, transforms.RandomRotation, [-10, 10.],
                     rotate_with_proba=invalid_proba)
    # `forward`: wrong dtype is rejected, output shape equals input shape
    rotator = transforms.RandomRotation([-10, 10.])
    assertRaises(TypeError, rotator, mx.nd.ones((3, 30, 60), dtype='uint8'))
    for shape in ((3, 30, 60), (3, 3, 30, 60)):
        rotated = rotator(mx.nd.ones(shape, dtype='float32'))
        assert same(rotated.shape, shape)
    # with rotate_with_proba = 0 the transform must act as the identity
    identity = transforms.RandomRotation([-100., 100.], rotate_with_proba=0.0)
    sample = mx.nd.random_normal(shape=(3, 30, 60))
    assert_almost_equal(sample, identity(sample))


@with_seed()
def test_random_transforms():
from mxnet.gluon.data.vision import transforms
Expand Down
Loading

0 comments on commit eab87f5

Please sign in to comment.