Skip to content

Commit

Permalink
Alignment added useful tilt/rotation fitting functions
Browse files Browse the repository at this point in the history
Plus a unit test

Signed-off-by: Nicola VIGANÒ <[email protected]>
  • Loading branch information
Obi-Wan committed May 29, 2024
1 parent dd9f30c commit 2fb58bc
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 1 deletion.
142 changes: 141 additions & 1 deletion corrct/alignment/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
and ESRF - The European Synchrotron, Grenoble, France
"""

from typing import Literal, Optional, Union, Sequence
from typing import Optional, Union, Sequence

import numpy as np
from numpy.polynomial import Polynomial
import scipy.ndimage as spimg
import scipy.optimize as spopt
from numpy.typing import ArrayLike, NDArray
from skimage.transform import warp_polar
from skimage.filters import window
from scipy.optimize import minimize
import matplotlib.pyplot as plt

NDArrayFloat = NDArray[np.floating]

Expand Down Expand Up @@ -239,6 +243,142 @@ def fit_shifts_zyx_xc(
return shifts_zyx


def fit_image_rotation_and_scale(
img_1_vu: NDArray, img_2_vu: NDArray, pad_mode: Union[str, None] = None, window_type: str = "hann", verbose: bool = False
) -> tuple[float, float]:
"""Fit the rotation and scaling of an image against a reference image. This works best for larger rotation angles.
Parameters
----------
img_1_vu : NDArray
Reference image
img_2_vu : NDArray
Rotated and scaled image
pad_mode : Union[str, None], optional
Padding mode, by default None
window_type : str, optional
Windowing type (to cud the high frequency aliasing), by default "hann"
verbose : bool, optional
Whether to give verbose output, by default False
Returns
-------
tuple[float, float]
The rotation (in degrees) and scale of the second image with respect to the first
Raises
------
ValueError
In case of mismatching shape of the two images.
"""
if img_1_vu.ndim != img_2_vu.ndim or np.any(np.array(img_1_vu.shape) != np.array(img_2_vu.shape)):
raise ValueError(
f"Image shapes should be identical, but instead got image #1: {img_1_vu.shape}, and image #2: {img_2_vu.shape}"
)

axes = (-2, -1)

img_shape = img_2_vu.shape
if pad_mode is not None:
pad_widths = [(s // 2,) for s in img_shape]
img_1_vu = np.pad(img_1_vu, pad_width=pad_widths, mode=pad_mode)
img_2_vu = np.pad(img_2_vu, pad_width=pad_widths, mode=pad_mode)
img_shape = img_2_vu.shape

img_win = window(window_type=window_type, shape=img_shape)

img_fft_1 = np.fft.fft2(img_1_vu * img_win, axes=axes)
img_fft_2 = np.fft.fft2(img_2_vu * img_win, axes=axes)

# abs removes the translation component
img_fft_1 = np.abs(np.fft.fftshift(img_fft_1, axes=axes))
img_fft_2 = np.abs(np.fft.fftshift(img_fft_2, axes=axes))

# transform to polar coordinates
img_center = [s - s // 2 for s in img_shape]
radius = min([s // 2 for s in img_shape])
img_fft_1_p = warp_polar(img_fft_1, center=img_center, scaling="log", radius=radius)
img_fft_2_p = warp_polar(img_fft_2, center=img_center, scaling="log", radius=radius)

# only use half of FFT
img_fft_1_p = img_fft_1_p[..., : img_fft_1_p.shape[0] // 2, :]
img_fft_2_p = img_fft_2_p[..., : img_fft_2_p.shape[0] // 2, :]

fft_polar_shifts_rs = fit_shifts_vu_xc(img_fft_1_p[:, None, 1:], img_fft_2_p[:, None, 1:])
tilt_pix = np.squeeze(fft_polar_shifts_rs[0])
tilt_deg = (180 / img_fft_2_p.shape[0]) * tilt_pix

klog = img_fft_2_p.shape[1] / np.log(radius)
scale = np.exp(np.squeeze(fft_polar_shifts_rs[1]) / klog)

if verbose:
print(f"Fitted image rotation: {tilt_deg:.6} (degrees) or {tilt_pix} (pixels), with scale factor: {scale:.6}")

return tilt_deg, scale


def fit_camera_tilt_angle(img_1: NDArray, img_2: NDArray, pad_u: bool = False, fit_l1: bool = True, verbose: bool = False):
"""
Estimate the camera tilt angle based on correlation peak values between two images.
Parameters
----------
img_1: NDArray
The first image.
img_2: NDArray
The second image.
pad_u: bool, optional
Enable zero padding. Default is False.
fit_l1: bool, optional
Perform L1 norm fitting if True. Default is True.
verbose: bool, optional
Enable verbose output. Default is False.
Returns
-------
tuple[float, float]
Tuple containing the estimated center of rotation offset (pixels) and camera tilt angle (degrees).
"""
fitted_shifts_h = fit_shifts_vu_xc(img_1, img_2, pad_u=pad_u)
fitted_cors = fitted_shifts_h / 2

# Computing tilt
img_shape = img_2.shape
half_img_size = (img_shape[-2] - 1) / 2
cc_v_coords = np.linspace(-half_img_size, half_img_size, img_shape[-2])
poly_slope = Polynomial.fit(cc_v_coords, fitted_cors, deg=1)
b, a = poly_slope.convert().coef

if fit_l1:

def f(coeffs: NDArray) -> float:
b, a = coeffs[0], coeffs[1]
pred_line = cc_v_coords * a + b
l1_diff = np.linalg.norm(pred_line - fitted_cors, ord=1)
return float(l1_diff)

coeffs_opt = minimize(f, np.array([b, a]))
b, a = coeffs_opt.x

tilt_deg = np.rad2deg(-a / 2)
cor_offset_pix = b

if verbose:
cor_trend = Polynomial([b, a])
print(f"Fitted center of rotation (pixels): {cor_offset_pix}, and camera tilt (degrees): {tilt_deg}")
fig, axs = plt.subplots(1, 1)
axs.scatter(cc_v_coords, fitted_cors, label="Line CoRs")
axs.plot(cc_v_coords, cor_trend(cc_v_coords), "-C1", label="Line CoRs trend")
axs.axhline(cor_offset_pix, color="C2", linestyle="--", label=f"Image CoR ({cor_offset_pix:.3})")
axs.set_title("Correlation peaks")
axs.grid()
axs.legend(fontsize=13)
fig.tight_layout()
plt.show(block=False)

return cor_offset_pix, tilt_deg


def sinusoid(
x: Union[NDArrayFloat, float], a: Union[NDArrayFloat, float], p: Union[NDArrayFloat, float], b: Union[NDArrayFloat, float]
) -> NDArrayFloat:
Expand Down
61 changes: 61 additions & 0 deletions tests/test_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from scipy.spatial.transform import Rotation
import pytest
import skimage.data as skd
import skimage.transform as skt
from numpy.typing import NDArray
import corrct as cct

Expand Down Expand Up @@ -162,6 +163,66 @@ def test_set_detector_tilt():
assert np.allclose(prj_geom_test.src_pos_xyz, expected_src), "Failed: Tilting the source"


def test_fit_image_rotation_and_scale():
"""
Test the fit_image_rotation_and_scale function from the cct.alignment.fitting module.
This function should return a tuple (angle, scale) that represents the rotation angle
and scaling factor required to align an input image with a reference image. The
function is tested using various combinations of identical images, simple rotations,
scaling, and both rotation and scaling. It is also tested with images of different
shapes, in which case it should raise a ValueError.
"""
# Load the camera image from scikit-image's data package
img: NDArray = skd.camera() / 255

# Test with two identical images should return 0 degrees and 1 scale
result = cct.alignment.fitting.fit_image_rotation_and_scale(img, img)
assert np.allclose(result, (0.0, 1.0))

# Test with a simple rotation of an image by 10 degrees
img_10 = skt.rotate(img, angle=10.0)
result = cct.alignment.fitting.fit_image_rotation_and_scale(img, img_10)
assert np.allclose(result, (10.0, 1), rtol=0.01)

# Test with a simple rotation of an image by 45 degrees
img_45 = skt.rotate(img, angle=45.0)
result = cct.alignment.fitting.fit_image_rotation_and_scale(img, img_45)
assert np.allclose(result, (45.0, 1), rtol=0.01)

# Test with a rotation of an image by 360 degrees
img_360 = np.copy(img)
result = cct.alignment.fitting.fit_image_rotation_and_scale(img, img_360)
assert np.allclose(result, (0.0, 1.0), rtol=0.01)

# Test with a scaling of an image by 10% using scikit-image's rescale function
img_scaled: NDArray = skt.rescale(img, scale=1.1)
center_y, center_x = img.shape[0] // 2, img.shape[1] // 2
img_scaled_cropped = img_scaled[
center_y - img.shape[0] // 2 : center_y + img.shape[0] // 2,
center_x - img.shape[1] // 2 : center_x + img.shape[1] // 2,
]
result = cct.alignment.fitting.fit_image_rotation_and_scale(img, img_scaled_cropped)
assert np.allclose(result, (0.0, 1.1), rtol=0.01, atol=0.05)

# Test with a rotation and scaling of an image
img_scaled = skt.rescale(img, scale=1.1)
img_scaled_cropped = img_scaled[
center_y - img.shape[0] // 2 : center_y + img.shape[0] // 2,
center_x - img.shape[1] // 2 : center_x + img.shape[1] // 2,
]
img_10_scaled = skt.rotate(img_scaled_cropped, angle=10.0)
result = cct.alignment.fitting.fit_image_rotation_and_scale(img, img_10_scaled)
assert np.allclose(result, (10.0, 1.1), rtol=0.01)

# Test with images of different shapes should raise a ValueError
try:
cct.alignment.fitting.fit_image_rotation_and_scale(img, np.random.rand(10, 20))
assert False, "Expected a ValueError"
except ValueError:
pass


@pytest.mark.skipif(not cct.projectors.astra_available, reason="astra-toolbox not available")
@pytest.mark.parametrize("add_noise", [(False,), (True,)])
def test_pre_alignment(add_noise: bool, theo_rot_axis: float = -1.25):
Expand Down

0 comments on commit 2fb58bc

Please sign in to comment.