Skip to content

Commit

Permalink
Wrap-safe transfer functions (#175)
Browse files Browse the repository at this point in the history
* helper functions

* fluorescence wrap safety

* 3d phase wrap safety

* fix axial nyquist bug

* 2d phase wrap safety

* fix interaction between padding and wrap safety
  • Loading branch information
talonchandler authored Nov 9, 2024
1 parent a7ac036 commit cbda81d
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 4 deletions.
48 changes: 46 additions & 2 deletions waveorder/models/isotropic_fluorescent_thick_3d.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Literal

import numpy as np
import torch
from torch import Tensor

from waveorder import optics, util
from waveorder import optics, sampling, util


def generate_test_phantom(
Expand All @@ -28,6 +29,49 @@ def calculate_transfer_function(
index_of_refraction_media,
numerical_aperture_detection,
):

transverse_nyquist = sampling.transverse_nyquist(
wavelength_emission,
numerical_aperture_detection, # ill = det for fluorescence
numerical_aperture_detection,
)
axial_nyquist = sampling.axial_nyquist(
wavelength_emission,
numerical_aperture_detection,
index_of_refraction_media,
)

yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
z_factor = int(np.ceil(z_pixel_size / axial_nyquist))

optical_transfer_function = _calculate_wrap_unsafe_transfer_function(
(
zyx_shape[0] * z_factor,
zyx_shape[1] * yx_factor,
zyx_shape[2] * yx_factor,
),
yx_pixel_size / yx_factor,
z_pixel_size / z_factor,
wavelength_emission,
z_padding,
index_of_refraction_media,
numerical_aperture_detection,
)
zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
return sampling.nd_fourier_central_cuboid(
optical_transfer_function, zyx_out_shape
)


def _calculate_wrap_unsafe_transfer_function(
zyx_shape,
yx_pixel_size,
z_pixel_size,
wavelength_emission,
z_padding,
index_of_refraction_media,
numerical_aperture_detection,
):
radial_frequencies = util.generate_radial_frequencies(
zyx_shape[1:], yx_pixel_size
)
Expand Down Expand Up @@ -97,7 +141,7 @@ def apply_transfer_function(
Returns
-------
Simulated data : torch.Tensor
"""
if (
zyx_object.shape[0] + 2 * z_padding
Expand Down
61 changes: 60 additions & 1 deletion waveorder/models/isotropic_thin_3d.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Literal, Tuple

import numpy as np
import torch
from torch import Tensor

from waveorder import optics, util
from waveorder import optics, sampling, util


def generate_test_phantom(
Expand Down Expand Up @@ -42,6 +43,64 @@ def calculate_transfer_function(
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
transverse_nyquist = sampling.transverse_nyquist(
wavelength_illumination,
numerical_aperture_illumination,
numerical_aperture_detection,
)
yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))

absorption_2d_to_3d_transfer_function, phase_2d_to_3d_transfer_function = (
_calculate_wrap_unsafe_transfer_function(
(
yx_shape[0] * yx_factor,
yx_shape[1] * yx_factor,
),
yx_pixel_size / yx_factor,
z_position_list,
wavelength_illumination,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=invert_phase_contrast,
)
)

absorption_2d_to_3d_transfer_function_out = torch.zeros(
(len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64
)
phase_2d_to_3d_transfer_function_out = torch.zeros(
(len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64
)

for z in range(len(z_position_list)):
absorption_2d_to_3d_transfer_function_out[z] = (
sampling.nd_fourier_central_cuboid(
absorption_2d_to_3d_transfer_function[z], yx_shape
)
)
phase_2d_to_3d_transfer_function_out[z] = (
sampling.nd_fourier_central_cuboid(
phase_2d_to_3d_transfer_function[z], yx_shape
)
)

return (
absorption_2d_to_3d_transfer_function_out,
phase_2d_to_3d_transfer_function_out,
)


def _calculate_wrap_unsafe_transfer_function(
yx_shape,
yx_pixel_size,
z_position_list,
wavelength_illumination,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
if invert_phase_contrast:
z_position_list = torch.flip(torch.tensor(z_position_list), dims=(0,))
Expand Down
56 changes: 55 additions & 1 deletion waveorder/models/phase_thick_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch import Tensor

from waveorder import optics, util
from waveorder import optics, sampling, util
from waveorder.models import isotropic_fluorescent_thick_3d


Expand Down Expand Up @@ -40,6 +40,60 @@ def calculate_transfer_function(
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
transverse_nyquist = sampling.transverse_nyquist(
wavelength_illumination,
numerical_aperture_illumination,
numerical_aperture_detection,
)
axial_nyquist = sampling.axial_nyquist(
wavelength_illumination,
numerical_aperture_detection,
index_of_refraction_media,
)

yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
z_factor = int(np.ceil(z_pixel_size / axial_nyquist))

real_potential_transfer_function, imag_potential_transfer_function = (
_calculate_wrap_unsafe_transfer_function(
(
zyx_shape[0] * z_factor,
zyx_shape[1] * yx_factor,
zyx_shape[2] * yx_factor,
),
yx_pixel_size / yx_factor,
z_pixel_size / z_factor,
wavelength_illumination,
z_padding,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=invert_phase_contrast,
)
)

zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
return (
sampling.nd_fourier_central_cuboid(
real_potential_transfer_function, zyx_out_shape
),
sampling.nd_fourier_central_cuboid(
imag_potential_transfer_function, zyx_out_shape
),
)


def _calculate_wrap_unsafe_transfer_function(
zyx_shape,
yx_pixel_size,
z_pixel_size,
wavelength_illumination,
z_padding,
index_of_refraction_media,
numerical_aperture_illumination,
numerical_aperture_detection,
invert_phase_contrast=False,
):
radial_frequencies = util.generate_radial_frequencies(
zyx_shape[1:], yx_pixel_size
Expand Down
94 changes: 94 additions & 0 deletions waveorder/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np
import torch


def transverse_nyquist(
wavelength_emission,
numerical_aperture_illumination,
numerical_aperture_detection,
):
"""Transverse Nyquist sample spacing in `wavelength_emission` units.
For widefield label-free imaging, the transverse Nyquist sample spacing is
lambda / (2 * (NA_ill + NA_det)).
Perhaps surprisingly, the transverse Nyquist sample spacing for widefield
fluorescence is lambda / (4 * NA), which is equivalent to the above formula
when NA_ill = NA_det.
Parameters
----------
wavelength_emission : float
Output units match these units
numerical_aperture_illumination : float
For widefield fluorescence, set to numerical_aperture_detection
numerical_aperture_detection : float
Returns
-------
float
Transverse Nyquist sample spacing
"""
return wavelength_emission / (
2 * (numerical_aperture_detection + numerical_aperture_illumination)
)


def axial_nyquist(
wavelength_emission,
numerical_aperture_detection,
index_of_refraction_media,
):
"""Axial Nyquist sample spacing in `wavelength_emission` units.
For widefield microscopes, the axial Nyquist cutoff frequency is:
(n/lambda) - sqrt( (n/lambda)^2 - (NA_det/lambda)^2 ),
and the axial Nyquist sample spacing is 1 / (2 * cutoff_frequency).
Perhaps surprisingly, the axial Nyquist sample spacing is independent of
the illumination numerical aperture.
Parameters
----------
wavelength_emission : float
Output units match these units
numerical_aperture_detection : float
index_of_refraction_media: float
Returns
-------
float
Axial Nyquist sample spacing
"""
n_on_lambda = index_of_refraction_media / wavelength_emission
cutoff_frequency = n_on_lambda - np.sqrt(
n_on_lambda**2
- (numerical_aperture_detection / wavelength_emission) ** 2
)
return 1 / (2 * cutoff_frequency)


def nd_fourier_central_cuboid(source, target_shape):
"""Central cuboid of an N-D Fourier transform.
Parameters
----------
source : torch.Tensor
Source tensor
target_shape : tuple of int
Returns
-------
torch.Tensor
Center cuboid in Fourier space
"""
center_slices = tuple(
slice((s - o) // 2, (s - o) // 2 + o)
for s, o in zip(source.shape, target_shape)
)
return torch.fft.ifftshift(torch.fft.fftshift(source)[center_slices])

0 comments on commit cbda81d

Please sign in to comment.