Skip to content

Commit

Permalink
Add elastic transformation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 464883355
  • Loading branch information
pabloduque0 authored and PIXDev committed Aug 26, 2022
1 parent 0af35d4 commit 07ece37
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 9 deletions.
2 changes: 2 additions & 0 deletions dm_pix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
adjust_hue = augment.adjust_hue
adjust_saturation = augment.adjust_saturation
affine_transform = augment.affine_transform
elastic_deformation = augment.elastic_deformation
flip_left_right = augment.flip_left_right
flip_up_down = augment.flip_up_down
gaussian_blur = augment.gaussian_blur
Expand Down Expand Up @@ -80,6 +81,7 @@
"adjust_saturation",
"affine_transform",
"depth_to_space",
"elastic_deformation",
"extract_patches",
"flat_nd_linear_interpolate",
"flat_nd_linear_interpolate_constant",
Expand Down
137 changes: 128 additions & 9 deletions dm_pix/_src/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""

import functools
from typing import Sequence, Tuple, Union
from typing import Callable, Sequence, Tuple, Union

import chex
from dm_pix._src import color_conversion
Expand Down Expand Up @@ -152,6 +152,91 @@ def adjust_saturation(
return jnp.stack(rgb_adjusted, axis=channel_axis)


def elastic_deformation(
    key: chex.PRNGKey,
    image: chex.Array,
    alpha: chex.Numeric,
    sigma: chex.Numeric,
    *,
    order: int = 1,
    mode: str = "nearest",
    cval: float = 0.,
    channel_axis: int = -1,
) -> chex.Array:
  """Warps an image with a random, gaussian-smoothed displacement field.

  Elastic deformations were introduced by [Simard, 2003] and popularized by
  [Ronneberger, 2015]. Two independent uniform noise maps (one per spatial
  axis) are blurred with a gaussian kernel, scaled by `alpha`, and added to
  the pixel grid; the image is then resampled at the displaced coordinates.
  The same displacement field is applied to every channel.

  Small sigma values (< 1.) give pixelated images, while higher values give
  water-like results. For sensible results, alpha should be between 5x and
  10x the value given for sigma.

  Args:
    key: a JAX RNG key.
    image: a JAX array representing an image. Assumes that the image is
      either HWC or CHW.
    alpha: strength of the distortion field. Higher values mean that pixels
      are moved further with respect to the distortion field's direction.
    sigma: standard deviation of the gaussian kernel used to smooth the
      distortion fields. The blur kernel size is derived from this value.
    order: the order of the spline interpolation, default is 1. The order has
      to be in the range [0, 1]. Note that PIX interpolation will only be
      used for order=1, for other values we use
      `jax.scipy.ndimage.map_coordinates`.
    mode: determines how the input array is extended beyond its boundaries.
      Default is 'nearest'. Modes 'nearest' and 'constant' use PIX
      interpolation, which is very fast on accelerators (especially on TPUs).
      For all other modes, 'wrap', 'mirror' and 'reflect', we rely on
      `jax.scipy.ndimage.map_coordinates`, which however is slow on
      accelerators, so use it with care.
    cval: value to fill past edges of input if mode is 'constant'. Default is
      0.0.
    channel_axis: the index of the channel axis.

  Returns:
    The transformed image, with the same axis layout as the input.
  """
  chex.assert_rank(image, 3)
  channels_last = channel_axis == -1
  if not channels_last:
    # Work in channels-last layout; the caller's layout is restored on exit.
    image = jnp.moveaxis(image, source=channel_axis, destination=-1)

  field_shape = image.shape[:-1] + (1,)
  key_i, key_j = jax.random.split(key)

  # 3 sigma on each side of the kernel's center covers ~99.7% of the
  # probability mass.
  kernel_size = sigma * 3 * 2

  def _displacement(noise_key):
    # Uniform noise rescaled to [-1, 1), smoothed, then scaled by alpha.
    noise = jax.random.uniform(noise_key, shape=field_shape) * 2 - 1
    blurred = gaussian_blur(image=noise, sigma=sigma, kernel_size=kernel_size)
    return blurred * alpha

  shift_i = _displacement(key_i)
  shift_j = _displacement(key_j)

  # Identity sampling grid; only the two spatial axes are displaced.
  grid_i, grid_j, grid_c = jnp.meshgrid(
      *(jnp.arange(size) for size in field_shape), indexing="ij")
  coordinates = jnp.asarray([grid_i + shift_i, grid_j + shift_j, grid_c])

  interpolate = _get_interpolate_function(mode=mode, order=order, cval=cval)

  # Resample each channel separately with the shared coordinates.
  warped_channels = []
  for channel in range(image.shape[-1]):
    single_channel = image[..., channel, jnp.newaxis]
    warped_channels.append(interpolate(single_channel, coordinates))
  warped = jnp.concatenate(warped_channels, axis=-1)

  if not channels_last:
    # Set channel axis back to original index.
    warped = jnp.moveaxis(warped, source=-1, destination=channel_axis)
  return warped


def flip_left_right(
image: chex.Array,
*,
Expand Down Expand Up @@ -382,14 +467,11 @@ def affine_transform(
offset = jnp.full((3,), fill_value=offset)
coordinates += jnp.reshape(a=offset, newshape=(*offset.shape, 1, 1, 1))

if mode == "nearest" and order == 1:
interpolate_function = interpolation.flat_nd_linear_interpolate
elif mode == "constant" and order == 1:
interpolate_function = functools.partial(
interpolation.flat_nd_linear_interpolate_constant, cval=cval)
else:
interpolate_function = functools.partial(
jax.scipy.ndimage.map_coordinates, mode=mode, order=order, cval=cval)
interpolate_function = _get_interpolate_function(
mode=mode,
order=order,
cval=cval,
)
return interpolate_function(image, coordinates)


Expand Down Expand Up @@ -603,3 +685,40 @@ def _depthwise_conv2d(
padding,
feature_group_count=inputs.shape[channel_axis],
dimension_numbers=dimension_numbers)


def _get_interpolate_function(
    mode: str,
    order: int,
    cval: float = 0.,
) -> Callable[[chex.Array, chex.Array], chex.Array]:
  """Picks the interpolation routine matching the given mode and order.

  PIX interpolations are preferred given they are faster on accelerators.
  Where PIX does not implement the requested combination we rely on
  `jax.scipy.ndimage.map_coordinates`.

  Args:
    mode: determines how the input array is extended beyond its boundaries.
      Modes 'nearest' and 'constant' use PIX interpolation (when order is 1),
      which is very fast on accelerators (especially on TPUs). For all other
      modes, 'wrap', 'mirror' and 'reflect', we rely on
      `jax.scipy.ndimage.map_coordinates`, which however is slow on
      accelerators, so use it with care.
    order: the order of the spline interpolation. The order has to be in the
      range [0, 1]. Note that PIX interpolation will only be used for
      order=1, for other values we use `jax.scipy.ndimage.map_coordinates`.
    cval: value to fill past edges of input if mode is 'constant'.

  Returns:
    A callable taking `(image, coordinates)` and returning the interpolated
    array.
  """
  if order == 1:
    # Linear interpolation: use the PIX implementations where available.
    if mode == "nearest":
      return interpolation.flat_nd_linear_interpolate
    if mode == "constant":
      return functools.partial(
          interpolation.flat_nd_linear_interpolate_constant, cval=cval)

  # Every other (mode, order) combination falls back to jax.scipy.
  return functools.partial(
      jax.scipy.ndimage.map_coordinates, mode=mode, order=order, cval=cval)
22 changes: 22 additions & 0 deletions dm_pix/_src/augment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,28 @@ def test_random_crop(self, images_list):
crop_fn = lambda img: augment.random_crop(key, img, (100, 100, 3))
self._test_fn(images_list, jax_fn=crop_fn, reference_fn=None)

@parameterized.named_parameters(
    ("in_range", _RAND_FLOATS_IN_RANGE),
    ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE),
)
def test_elastic_deformation(self, images_list):
  """Runs elastic_deformation with a fixed alpha and with a random alpha."""
  rng = jax.random.PRNGKey(43)
  # Sigma has to be constant for jit since kernel_size is derived from it.
  deform_with_fixed_sigma = functools.partial(
      augment.elastic_deformation, rng, sigma=5.)

  # Case 1: alpha is bound up front.
  deform_with_fixed_alpha = functools.partial(
      deform_with_fixed_sigma, alpha=10)
  self._test_fn(
      images_list, jax_fn=deform_with_fixed_alpha, reference_fn=None)

  # Case 2: alpha is supplied by the test helper from the given range.
  self._test_fn_with_random_arg(
      images_list,
      jax_fn=deform_with_fixed_sigma,
      reference_fn=None,
      alpha=(40, 80))


class TestMatchReference(_ImageAugmentationTest):

Expand Down
6 changes: 6 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Augmentations
adjust_hue
adjust_saturation
affine_transform
elastic_deformation
flip_left_right
flip_up_down
gaussian_blur
Expand Down Expand Up @@ -55,6 +56,11 @@ affine_transform

.. autofunction:: affine_transform

elastic_deformation
~~~~~~~~~~~~~~~~~~~

.. autofunction:: elastic_deformation

flip_left_right
~~~~~~~~~~~~~~~

Expand Down

0 comments on commit 07ece37

Please sign in to comment.