Skip to content

Commit

Permalink
Device agnostic compute for polarization (#150)
Browse files Browse the repository at this point in the history
undefined
  • Loading branch information
ziw-liu authored Jan 6, 2024
1 parent 366ef1b commit 8f1f65c
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 38 deletions.
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch


def device_params():
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
if torch.backends.mps.is_available():
devices.append("mps")
return "device", devices


_DEVICE = device_params()
14 changes: 9 additions & 5 deletions tests/models/test_inplane_oriented_thick_pol3D.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from waveorder import stokes
from tests.conftest import _DEVICE
from waveorder.models import inplane_oriented_thick_pol3d


Expand All @@ -16,23 +16,27 @@ def test_calculate_transfer_function():
assert intensity_to_stokes_matrix.shape == (4, 5)


def test_apply_inverse_transfer_function():
input_shape = (5, 10, 5, 5)
czyx_data = torch.rand(input_shape)
@pytest.mark.parametrize(*_DEVICE)
@pytest.mark.parametrize("estimate_bg", [True, False])
def test_apply_inverse_transfer_function(device, estimate_bg):
input_shape = (5, 10, 100, 100)
czyx_data = torch.rand(input_shape, device=device)

intensity_to_stokes_matrix = (
inplane_oriented_thick_pol3d.calculate_transfer_function(
swing=0.1,
scheme="5-State",
)
).to(device)
)

results = inplane_oriented_thick_pol3d.apply_inverse_transfer_function(
czyx_data=czyx_data,
intensity_to_stokes_matrix=intensity_to_stokes_matrix,
remove_estimated_background=estimate_bg,
)

assert len(results) == 4

for result in results:
assert result.shape == input_shape[1:]
assert result.device.type == device
42 changes: 42 additions & 0 deletions tests/test_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
import torch

from tests.conftest import _DEVICE
from waveorder.correction import (
_fit_2d_polynomial_surface,
_grid_coordinates,
_sample_block_medians,
estimate_background,
)


def test_sample_block_medians():
image = torch.arange(4 * 5, dtype=torch.float).reshape(4, 5)
medians = _sample_block_medians(image, 2)
assert torch.allclose(
medians, torch.tensor([1, 3, 11, 13]).to(image.dtype)
)


def test_grid_coordinates():
image = torch.ones(15, 17)
coords = _grid_coordinates(image, 4)
assert coords.shape == (3 * 4, 2)


def test_fit_2d_polynomial_surface():
coords = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float)
values = torch.tensor([0, 1, 2, 3], dtype=torch.float)
surface = _fit_2d_polynomial_surface(coords, values, 1, (2, 2))
assert torch.allclose(surface, values.reshape(surface.shape), atol=1e-2)


@pytest.mark.parametrize("order", [1, 2, 3])
@pytest.mark.parametrize(*_DEVICE)
def test_estimate_background(order, device):
image = torch.rand(200, 200).to(device)
image[:100, :100] += 1
background = estimate_background(image, order=order, block_size=32)
assert 2.0 > background[50, 50] > 1.0
assert 1.5 > background[0, 100] > 0.5
assert 1.0 > background[150, 150] > 0.0
77 changes: 51 additions & 26 deletions tests/test_stokes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from waveorder import stokes

from .conftest import _DEVICE


def test_S2I_matrix():
S2I5 = stokes.calculate_stokes_to_intensity_matrix(0.1)
Expand Down Expand Up @@ -35,18 +37,26 @@ def test_I2S_matrix():
tt.assert_close(I, torch.eye(I.shape[0]))


def test_s12_to_orientation():
for orientation in torch.linspace(0, np.pi, 25)[:-1]: # skip endpoint
@pytest.mark.parametrize(*_DEVICE)
def test_s12_to_orientation(device):
for orientation in torch.linspace(0, np.pi, 25, device=device)[
:-1
]: # skip endpoint
orientation1 = stokes._s12_to_orientation(
np.sin(2 * orientation), -np.cos(2 * orientation)
torch.sin(2 * orientation), -torch.cos(2 * orientation)
)
tt.assert_close(orientation, orientation1)


def test_stokes_recon():
# NOTE: skip retardance = 0 and depolarization = 0 because orientationentation is not defined
for retardance in torch.arange(1e-3, 1, 0.1): # fractions of a wave
for orientation in torch.arange(0, np.pi, np.pi / 10): # radians
@pytest.mark.parametrize(*_DEVICE)
def test_stokes_recon(device):
# NOTE: skip retardance = 0 and depolarization = 0 because orientation is not defined
for retardance in torch.arange(
1e-3, 1, 0.1, device=device
): # fractions of a wave
for orientation in torch.arange(
0, np.pi, np.pi / 10, device=device
): # radians
for transmittance in [0.1, 10]:
# Test attenuating retarder (ar) functions
ar = (retardance, orientation, transmittance)
Expand All @@ -56,7 +66,9 @@ def test_stokes_recon():
tt.assert_close(torch.tensor(ar[i]), ar1[i])

# Test attenuating depolarizing retarder (adr) functions
for depolarization in torch.arange(1e-3, 1, 0.1):
for depolarization in torch.arange(
1e-3, 1, 0.1, device=device
):
adr = (
retardance,
orientation,
Expand Down Expand Up @@ -109,22 +121,24 @@ def test_mueller_from_stokes():
tt.assert_close(torch.linalg.inv(M2), M2.T)


def test_mmul():
M = torch.ones((3, 2, 1))
x = torch.ones((2, 1))

@pytest.mark.parametrize(*_DEVICE)
def test_mmul(device):
M = torch.ones((3, 2, 1), device=device)
x = torch.ones((2, 1), device=device)
y = stokes.mmul(M, x) # should pass

assert y.shape == (3, 1)
assert y.device.type == device
with pytest.raises(ValueError):
M2 = torch.ones((3, 4, 1))
y2 = stokes.mmul(M2, x)


def test_copying():
a = torch.tensor([1, 1])
b = torch.tensor([1, 1])
c = torch.tensor([1, 1])
d = torch.tensor([1, 1])
@pytest.mark.parametrize(*_DEVICE)
def test_copying(device):
a = torch.tensor([1, 1], device=device)
b = torch.tensor([1, 1], device=device)
c = torch.tensor([1, 1], device=device)
d = torch.tensor([1, 1], device=device)
s0, s1, s2, s3 = stokes.stokes_after_adr(a, b, c, d)
s0[0] = 2 # modify the output
assert c[0] == 1 # check that the input hasn't changed
Expand All @@ -134,14 +148,19 @@ def test_copying():
assert a[0] == 1


def test_orientation_offset():
@pytest.mark.parametrize(*_DEVICE)
def test_orientation_offset(device):
ori = torch.tensor(
[0, torch.pi / 4, torch.pi / 2, torch.pi - 0.01, torch.pi]
[0, torch.pi / 4, torch.pi / 2, torch.pi - 0.01, torch.pi],
device=device,
)

ff = stokes.apply_orientation_offset(ori, rotate=False, flip=False)
assert torch.allclose(
ff, torch.tensor([0, torch.pi / 4, torch.pi / 2, torch.pi - 0.01, 0])
ff,
torch.tensor(
[0, torch.pi / 4, torch.pi / 2, torch.pi - 0.01, 0], device=device
),
)

tf = stokes.apply_orientation_offset(ori, rotate=True, flip=False)
Expand All @@ -154,26 +173,32 @@ def test_orientation_offset():
0,
(torch.pi / 2) - 0.01,
torch.pi / 2,
]
],
device=device,
),
)

ft = stokes.apply_orientation_offset(ori, rotate=False, flip=True)
assert torch.allclose(
ft,
torch.tensor([0, 3 * torch.pi / 4, torch.pi / 2, 0.01, 0]),
torch.tensor(
[0, 3 * torch.pi / 4, torch.pi / 2, 0.01, 0], device=device
),
)

tt = stokes.apply_orientation_offset(ori, rotate=True, flip=True)
rotated_fliped = stokes.apply_orientation_offset(
ori, rotate=True, flip=True
)
assert torch.allclose(
tt,
rotated_fliped,
torch.tensor(
[
torch.pi / 2,
torch.pi / 4,
0,
(torch.pi / 2) + 0.01,
torch.pi / 2,
]
],
device=device,
),
)
107 changes: 107 additions & 0 deletions waveorder/correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Background correction methods"""

import torch
import torch.nn.functional as F
from torch import Tensor, Size


def _sample_block_medians(image: Tensor, block_size) -> Tensor:
"""
Sample densely tiled square blocks from a 2D image and return their medians.
Incomplete blocks (overhangs) will be ignored.
Parameters
----------
image : Tensor
2D image
block_size : int, optional
Width and height of the blocks
Returns
-------
Tensor
Median intensity values for each block, flattened
"""
if not image.dtype.is_floating_point:
image.to(torch.float)
blocks = F.unfold(image[None, None], block_size, stride=block_size)[0]
return blocks.median(0)[0]


def _grid_coordinates(image: Tensor, block_size: int) -> Tensor:
"""Build image coordinates from the center points of square blocks"""
coords = torch.meshgrid(
[
torch.arange(
0 + block_size / 2,
boundary - block_size / 2 + 1,
block_size,
device=image.device,
)
for boundary in image.shape
]
)
return torch.stack(coords, dim=-1).reshape(-1, 2)


def _fit_2d_polynomial_surface(
coords: Tensor, values: Tensor, order: int, surface_shape: Size
) -> Tensor:
"""Fit a 2D polynomial to a set of coordinates and their values,
and return the surface evaluated at every point."""
n_coeffs = int((order + 1) * (order + 2) / 2)
if n_coeffs >= len(values):
raise ValueError(
f"Cannot fit a {order} degree 2D polynomial "
f"with {len(values)} sampled values"
)
orders = torch.arange(order + 1, device=coords.device)
order_pairs = torch.stack(torch.meshgrid(orders, orders), -1)
order_pairs = order_pairs[order_pairs.sum(-1) <= order].reshape(-1, 2)
terms = torch.stack(
[coords[:, 0] ** i * coords[:, 1] ** j for i, j in order_pairs], -1
)
# use "gels" driver for precision and GPU consistency
coeffs = torch.linalg.lstsq(terms, values, driver="gels").solution
dense_coords = torch.meshgrid(
[
torch.arange(s, dtype=values.dtype, device=values.device)
for s in surface_shape
]
)
dense_terms = torch.stack(
[dense_coords[0] ** i * dense_coords[1] ** j for i, j in order_pairs],
-1,
)
return torch.matmul(dense_terms, coeffs)


def estimate_background(image: Tensor, order: int = 2, block_size: int = 32):
"""
Combine sampling and polynomial surface fit for background estimation.
To background correct an image, divide it by the background.
Parameters
----------
image : Tensor
2D image
order : int, optional
Order of polynomial, by default 2
block_size : int, optional
Width and height of the blocks, by default 32
Returns
-------
Tensor
Background image
"""
if image.ndim != 2:
raise ValueError(f"Image must be 2D, got shape {image.shape}")
height, width = image.shape
if block_size > width:
raise ValueError("Block size larger than image height")
if block_size > height:
raise ValueError("Block size larger than image width")
medians = _sample_block_medians(image, block_size)
coords = _grid_coordinates(image, block_size)
return _fit_2d_polynomial_surface(coords, medians, order, image.shape)
8 changes: 3 additions & 5 deletions waveorder/models/inplane_oriented_thick_pol3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch import Tensor

from waveorder import background_estimator, stokes, util
from waveorder import correction, stokes, util


def generate_test_phantom(yx_shape):
Expand Down Expand Up @@ -125,7 +125,6 @@ def apply_inverse_transfer_function(

# Apply an "Estimated" background correction
if remove_estimated_background:
estimator = background_estimator.BackgroundEstimator2D()
for stokes_index in range(background_corrected_stokes.shape[0]):
# Project to 2D
z_projection = torch.mean(
Expand All @@ -134,9 +133,8 @@ def apply_inverse_transfer_function(
# Estimate the background and subtract
background_corrected_stokes[
stokes_index
] -= estimator.get_background(
z_projection,
normalize=False,
] -= correction.estimate_background(
z_projection, order=2, block_size=32
)

# Project to 2D (typically for SNR reasons)
Expand Down
Loading

0 comments on commit 8f1f65c

Please sign in to comment.