Skip to content

Commit

Permalink
Add general affine transform.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 451891634
  • Loading branch information
pabloduque0 authored and PIXDev committed Jul 8, 2022
1 parent 54b072b commit ea2980e
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 41 deletions.
2 changes: 2 additions & 0 deletions dm_pix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
adjust_gamma = augment.adjust_gamma
adjust_hue = augment.adjust_hue
adjust_saturation = augment.adjust_saturation
affine_transform = augment.affine_transform
flip_left_right = augment.flip_left_right
flip_up_down = augment.flip_up_down
gaussian_blur = augment.gaussian_blur
Expand Down Expand Up @@ -74,6 +75,7 @@
"adjust_gamma",
"adjust_hue",
"adjust_saturation",
"affine_transform",
"depth_to_space",
"extract_patches",
"flat_nd_linear_interpolate",
Expand Down
95 changes: 94 additions & 1 deletion dm_pix/_src/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
that of TensorFlow.
"""

from typing import Sequence, Tuple
import functools
from typing import Sequence, Tuple, Union

import chex
from dm_pix._src import color_conversion
from dm_pix._src import interpolation
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -297,6 +299,97 @@ def solarize(image: chex.Array, threshold: chex.Numeric) -> chex.Array:
return jnp.where(image < threshold, image, 1. - image)


def affine_transform(
    image: chex.Array,
    matrix: chex.Array,
    *,
    offset: Union[chex.Array, chex.Numeric] = 0.,
    order: int = 1,
    mode: str = "nearest",
    cval: float = 0.0,
) -> chex.Array:
  """Applies an affine transformation given by matrix.

  Given an output image pixel index vector o, the pixel value is determined from
  the input image at position jnp.dot(matrix, o) + offset.

  This does 'pull' (or 'backward') resampling, transforming the output space to
  the input to locate data. Affine transformations are often described in the
  'push' (or 'forward') direction, transforming input to output. If you have a
  matrix for the 'push' transformation, use its inverse (jax.numpy.linalg.inv)
  in this function.

  Args:
    image: a JAX array representing an image. Assumes that the image is
      either HWC or CHW. Must have rank 3.
    matrix: the inverse coordinate transformation matrix, mapping output
      coordinates to input coordinates. If ndim is the number of dimensions of
      input, the given matrix must have one of the following shapes:
      - (ndim, ndim): the linear transformation matrix for each output
        coordinate.
      - (ndim,): assume that the 2-D transformation matrix is diagonal, with
        the diagonal specified by the given value.
      - (ndim + 1, ndim + 1): assume that the transformation is specified using
        homogeneous coordinates [1]. In this case, any value passed to offset
        is ignored.
      - (ndim, ndim + 1): as above, but the bottom row of a homogeneous
        transformation matrix is always [0, 0, 0, 1], and may be omitted.
    offset: the offset into the array where the transform is applied. If a
      float, offset is the same for each axis. If an array, offset should
      contain one value for each axis.
    order: the order of the spline interpolation, default is 1. The order has
      to be in the range 0-1. Note that PIX interpolation will only be used for
      order=1, for other values we use `jax.scipy.ndimage.map_coordinates`.
    mode: the mode parameter determines how the input array is extended beyond
      its boundaries. Default is "nearest", using PIX
      `flat_nd_linear_interpolate` function, which is very fast on accelerators
      (especially on TPUs). For all other modes, 'constant', 'wrap', 'mirror'
      and 'reflect', we rely on `jax.scipy.ndimage.map_coordinates`, which
      however is slow on accelerators, so use it with care.
    cval: value to fill past edges of input if mode is 'constant'. Default is
      0.0.

  Returns:
    The input image transformed by the given matrix.

  Raises:
    ValueError: if the shape of `matrix` is not one of the supported shapes
      listed above.

  References:
    [1] https://en.wikipedia.org/wiki/Homogeneous_coordinates
  """
  chex.assert_rank(image, 3)
  chex.assert_rank(matrix, {1, 2})
  chex.assert_rank(offset, {0, 1})

  ndim = image.ndim
  error_msg = (
      "Expected matrix shape must be one of (ndim, ndim), (ndim,), "
      "(ndim + 1, ndim + 1) or (ndim, ndim + 1) being ndim the image.ndim. "
      f"The affine matrix provided has shape {matrix.shape}.")

  if matrix.ndim == 1:
    # Validate *before* expanding to a diagonal matrix: a vector of length
    # ndim + 1 would otherwise become an (ndim + 1, ndim + 1) matrix and be
    # silently misread as a homogeneous transform.
    if matrix.shape != (ndim,):
      raise ValueError(error_msg)
    matrix = jnp.diag(matrix)
  elif matrix.shape not in [
      (ndim, ndim), (ndim + 1, ndim + 1), (ndim, ndim + 1)]:
    raise ValueError(error_msg)

  # Build the grid of output pixel indices, shape (*image.shape, ndim).
  meshgrid = jnp.meshgrid(*[jnp.arange(size) for size in image.shape],
                          indexing="ij")
  indices = jnp.stack(meshgrid, axis=-1)

  if matrix.shape[1] == ndim + 1:
    # Homogeneous form: the last column carries the translation, which
    # overrides any user-provided offset.
    offset = matrix[:ndim, ndim]
    matrix = matrix[:ndim, :ndim]

  # Map every output index to its input location; move the coordinate axis to
  # the front as expected by the interpolation functions.
  coordinates = indices @ matrix.T
  coordinates = jnp.moveaxis(coordinates, source=-1, destination=0)

  # Alter coordinates to account for offset. A scalar offset is broadcast to
  # one value per axis. The shape is passed positionally to `jnp.reshape` so
  # the call does not depend on the `newshape`/`shape` keyword name.
  offset = jnp.full((ndim,), fill_value=offset)
  coordinates += jnp.reshape(offset, (ndim, 1, 1, 1))

  if mode == "nearest" and order == 1:
    interpolate_function = interpolation.flat_nd_linear_interpolate
  else:
    interpolate_function = functools.partial(
        jax.scipy.ndimage.map_coordinates, mode=mode, order=order, cval=cval)
  return interpolate_function(image, coordinates)


def random_flip_left_right(
key: chex.PRNGKey,
image: chex.Array,
Expand Down
Loading

0 comments on commit ea2980e

Please sign in to comment.