diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 178079349..1005a619d 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -3,28 +3,14 @@ _emd_hook = True from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_ptychography import ( - MixedstatePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_multislice_ptychography import ( - MultislicePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import ( - OverlapMagneticTomographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_tomography import ( - OverlapTomographicReconstruction, -) +from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction +from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction -from py4DSTEM.process.phase.iterative_simultaneous_ptychography import ( - SimultaneousPtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_singleslice_ptychography import ( - SingleslicePtychographicReconstruction, -) -from py4DSTEM.process.phase.parameter_optimize import ( - OptimizationParameter, - PtychographyOptimizer, -) +from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction +from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction +from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 6d7967550..04cfd6a60 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -56,6 +56,53 @@ def attach_datacube(self, datacube: DataCube): self._datacube = datacube return self + def reinitialize_parameters(self, device: str = None, verbose: bool = None): + """ + Reinitializes common parameters. This is useful when loading a previously-saved + reconstruction (which set device='cpu' and verbose=True for compatibility) , + using different initialization parameters. + + Parameters + ---------- + device: str, optional + If not None, imports and assigns appropriate device modules + verbose: bool, optional + If not None, sets the verbosity to verbose + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if device is not None: + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self._device = device + + if verbose is not None: + self._verbose = verbose + + return self + def set_save_defaults( self, save_datacube: bool = False, @@ -278,7 +325,9 @@ def _extract_intensities_and_calibrations_from_datacube( """ # Copies intensities to device casting to float32 - intensities = datacube.data + xp = self._xp + + intensities = xp.asarray(datacube.data, dtype=xp.float32) self._grid_scan_shape = intensities.shape[:2] # Extracts calibrations @@ -295,13 +344,14 @@ def _extract_intensities_and_calibrations_from_datacube( if require_calibrations: raise ValueError("Real-space calibrations must be given in 'A'") - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "real-space calibrations in 'A'" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "real-space calibrations in 'A'" + ), + UserWarning, + ) self._scan_sampling = (1.0, 1.0) self._scan_units = ("pixels",) * 2 @@ -359,13 +409,14 @@ def _extract_intensities_and_calibrations_from_datacube( "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" ) - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "appropriate reciprocal-space calibrations" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "appropriate reciprocal-space calibrations" + ), + UserWarning, + ) self._angular_sampling = (1.0, 1.0) self._angular_units = ("pixels",) * 2 @@ -448,8 +499,6 @@ def _calculate_intensities_center_of_mass( xp = self._xp asnumpy = self._asnumpy - intensities = xp.asarray(intensities, dtype=xp.float32) - # for ptycho if com_measured: com_measured_x, com_measured_y = com_measured @@ -484,9 +533,14 @@ def _calculate_intensities_center_of_mass( ) if com_shifts is None: + com_measured_x_np = asnumpy(com_measured_x) + com_measured_y_np = asnumpy(com_measured_y) + finite_mask = np.isfinite(com_measured_x_np) + com_shifts = fit_origin( - (asnumpy(com_measured_x), asnumpy(com_measured_y)), + (com_measured_x_np, com_measured_y_np), fitfunction=fit_function, + mask=finite_mask, ) # Fit function to center of mass @@ -494,12 +548,12 @@ def _calculate_intensities_center_of_mass( com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32) # fix CoM units - com_normalized_x = (com_measured_x - com_fitted_x) * self._reciprocal_sampling[ - 0 - ] - com_normalized_y = (com_measured_y - com_fitted_y) * self._reciprocal_sampling[ - 1 - ] + com_normalized_x = ( + xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0] + ) + com_normalized_y = ( + xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1] + ) return ( com_measured_x, @@ -1077,6 +1131,7 @@ def _normalize_diffraction_intensities( diffraction_intensities, com_fitted_x, com_fitted_y, + crop_patterns, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1089,6 +1144,9 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns + when centering Returns ------- @@ -1101,13 +1159,46 @@ def _normalize_diffraction_intensities( xp = self._xp mean_intensity = 0 - amplitudes = xp.zeros_like(diffraction_intensities) - region_of_interest_shape = diffraction_intensities.shape[-2:] + diffraction_intensities = self._asnumpy(diffraction_intensities) + if crop_patterns: + crop_x = int( + np.minimum( + diffraction_intensities.shape[2] - com_fitted_x.max(), + com_fitted_x.min(), + ) + ) + crop_y = int( + np.minimum( + diffraction_intensities.shape[3] - com_fitted_y.max(), + com_fitted_y.min(), + ) + ) + + crop_w = np.minimum(crop_y, crop_x) + region_of_interest_shape = (crop_w * 2, crop_w * 2) + amplitudes = np.zeros( + ( + diffraction_intensities.shape[0], + diffraction_intensities.shape[1], + crop_w * 2, + crop_w * 2, + ), + dtype=np.float32, + ) + + crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_) + crop_mask[:crop_w, :crop_w] = True + crop_mask[-crop_w:, :crop_w] = True + crop_mask[:crop_w:, -crop_w:] = True + crop_mask[-crop_w:, -crop_w:] = True + self._crop_mask = crop_mask + + else: + region_of_interest_shape = diffraction_intensities.shape[-2:] + amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32) com_fitted_x = self._asnumpy(com_fitted_x) com_fitted_y = self._asnumpy(com_fitted_y) - diffraction_intensities = self._asnumpy(diffraction_intensities) - amplitudes = self._asnumpy(amplitudes) for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): @@ -1119,16 +1210,71 @@ def _normalize_diffraction_intensities( device="cpu", ) + if crop_patterns: + intensities = intensities[crop_mask].reshape( + region_of_interest_shape + ) + mean_intensity += np.sum(intensities) amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) - amplitudes = xp.asarray(amplitudes, dtype=xp.float32) - amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) + amplitudes = xp.asarray(amplitudes) mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity + def show_complex_CoM( + self, + com=None, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot complex-valued CoM image + + Parameters + ---------- + + com = (CoM_x, CoM_y) tuple + If None is specified, uses (self.com_x, self.com_y) instead + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A + pixelsize: float, optional + default is scan sampling + """ + + if com is None: + com = (self.com_x, self.com_y) + + if pixelsize is None: + pixelsize = self._scan_sampling[0] + if pixelunits is None: + pixelunits = r"$\AA$" + + figsize = kwargs.pop("figsize", (6, 6)) + fig, ax = plt.subplots(figsize=figsize) + + complex_com = com[0] + 1j * com[1] + + show_complex( + complex_com, + cbar=cbar, + figax=(fig, ax), + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + **kwargs, + ) + class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints): """ @@ -1309,10 +1455,10 @@ def _get_constructor_args(cls, group): "object_type": instance_md["object_type"], "semiangle_cutoff": instance_md["semiangle_cutoff"], "rolloff": instance_md["rolloff"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], "polar_parameters": polar_params, + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } class_specific_kwargs = {} @@ -2109,6 +2255,7 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" figsize = kwargs.pop("figsize", (6, 6)) + chroma_boost = kwargs.pop("chroma_boost", 2) fig, ax = plt.subplots(figsize=figsize) show_complex( @@ -2119,6 +2266,7 @@ def show_fourier_probe( pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost=chroma_boost, **kwargs, ) @@ -2142,7 +2290,7 @@ def show_object_fft(self, obj=None, **kwargs): vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index 02138d738..af3cbbb45 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -195,9 +195,9 @@ def _get_constructor_args(cls, group): "datacube": dc, "initial_object_guess": np.asarray(obj), "energy": instance_md["energy"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } return kwargs @@ -718,24 +718,26 @@ def reconstruct( xp = self._xp asnumpy = self._asnumpy - if reset is None and hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - # Restart if store_iterations and (not hasattr(self, "object_phase_iterations") or reset): self.object_phase_iterations = [] - self.error_iterations = [] if reset: self.error = np.inf + self.error_iterations = [] self._step_size = step_size if step_size is not None else 0.5 self._padded_object_phase = self._padded_object_phase_initial.copy() + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] self.error = getattr(self, "error", np.inf) @@ -770,7 +772,8 @@ def reconstruct( if (new_error > self.error) and backtrack: self._padded_object_phase = previous_iteration self._step_size /= 2 - print(f"Iteration {a0}, step reduced to {self._step_size}") + if self._verbose: + print(f"Iteration {a0}, step reduced to {self._step_size}") continue self.error = new_error @@ -807,10 +810,11 @@ def reconstruct( self.error_iterations.append(self.error.item()) if self._step_size < stopping_criterion: - warnings.warn( - f"Step-size has decreased below stopping criterion {stopping_criterion}.", - UserWarning, - ) + if self._verbose: + warnings.warn( + f"Step-size has decreased below stopping criterion {stopping_criterion}.", + UserWarning, + ) # crop result self._object_phase = self._padded_object_phase[ @@ -840,7 +844,7 @@ def _visualize_last_iteration( If true, the NMSE error plot is displayed """ - figsize = kwargs.pop("figsize", (8, 8)) + figsize = kwargs.pop("figsize", (5, 6)) cmap = kwargs.pop("cmap", "magma") if plot_convergence: @@ -862,7 +866,7 @@ def _visualize_last_iteration( im = ax1.imshow(self.object_phase, extent=extent, cmap=cmap, **kwargs) ax1.set_ylabel(f"x [{self._scan_units[0]}]") ax1.set_xlabel(f"y [{self._scan_units[1]}]") - ax1.set_title(f"DPC Phase Reconstruction - NMSE error: {self.error:.3e}") + ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}") if cbar: divider = make_axes_locatable(ax1) @@ -870,11 +874,11 @@ def _visualize_last_iteration( fig.add_axes(ax_cb) fig.colorbar(im, cax=ax_cb) - if plot_convergence and hasattr(self, "_error_iterations"): - errors = self._error_iterations + if plot_convergence: + errors = self.error_iterations ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -979,7 +983,7 @@ def _visualize_all_iterations( if plot_convergence: ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -990,7 +994,7 @@ def visualize( fig=None, iterations_grid: Tuple[int, int] = None, plot_convergence: bool = True, - cbar: bool = False, + cbar: bool = True, **kwargs, ): """ diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py new file mode 100644 index 000000000..3eeb07814 --- /dev/null +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -0,0 +1,3511 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely multislice ptychography. +""" + +import warnings +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import pylops +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex + +try: + import cupy as cp +except ImportError: + cp = None + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, + spatial_frequencies, +) +from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate + +warnings.simplefilter(action="always", category=UserWarning) + + +class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction): + """ + Mixed-State Multislice Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (N,Sx,Sy) + Reconstructed object dimensions : (T,Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes + and (Px,Py) is the padded-object size we position our ROI around in + each of the T slices. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + num_probes: int, optional + Number of mixed-state probes + num_slices: int + Number of slices to use in the forward model + slice_thicknesses: float or Sequence[float] + Slice thicknesses in angstroms. If float, all slices are assigned the same thickness + datacube: DataCube, optional + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_num_probes", "_num_slices", "_slice_thicknesses") + + def __init__( + self, + energy: float, + num_slices: int, + slice_thicknesses: Union[float, Sequence[float]], + num_probes: int = None, + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + object_type: str = "complex", + verbose: bool = True, + device: str = "cpu", + name: str = "multi-slice_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): + if num_probes is None: + raise ValueError( + ( + "If initial_probe_guess is None, or a ComplexProbe object, " + "num_probes must be specified." + ) + ) + else: + if len(initial_probe_guess.shape) != 3: + raise ValueError( + "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." + ) + num_probes = initial_probe_guess.shape[0] + + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + slice_thicknesses = np.array(slice_thicknesses) + if slice_thicknesses.shape == (): + slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) + elif slice_thicknesses.shape[0] != (num_slices - 1): + raise ValueError( + ( + f"slice_thicknesses must have length {num_slices - 1}, " + f"not {slice_thicknesses.shape[0]}." + ) + ) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._verbose = verbose + self._device = device + self._preprocessed = False + + # Class-specific Metadata + self._num_probes = num_probes + self._num_slices = num_slices + self._slice_thicknesses = slice_thicknesses + + def _precompute_propagator_arrays( + self, + gpts: Tuple[int, int], + sampling: Tuple[float, float], + energy: float, + slice_thicknesses: Sequence[float], + ): + """ + Precomputes propagator arrays complex wave-function will be convolved by, + for all slice thicknesses. + + Parameters + ---------- + gpts: Tuple[int,int] + Wavefunction pixel dimensions + sampling: Tuple[float,float] + Wavefunction sampling in A + energy: float + The electron energy of the wave functions in eV + slice_thicknesses: Sequence[float] + Array of slice thicknesses in A + + Returns + ------- + propagator_arrays: np.ndarray + (T,Sx,Sy) shape array storing propagator arrays + """ + xp = self._xp + + # Frequencies + kx, ky = spatial_frequencies(gpts, sampling) + kx = xp.asarray(kx, dtype=xp.float32) + ky = xp.asarray(ky, dtype=xp.float32) + + # Propagators + wavelength = electron_wavelength_angstrom(energy) + num_slices = slice_thicknesses.shape[0] + propagators = xp.empty( + (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 + ) + for i, dz in enumerate(slice_thicknesses): + propagators[i] = xp.exp( + 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) + ) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) + + return propagators + + def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): + """ + Propagates array by Fourier convolving array with propagator_array. + + Parameters + ---------- + array: np.ndarray + Wavefunction array to be convolved + propagator_array: np.ndarray + Propagator array to convolve array with + + Returns + ------- + propagated_array: np.ndarray + Fourier-convolved array + """ + xp = self._xp + + return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + probe_roi_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: float = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (T,Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + probe_roi_shape, (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + + Returns + -------- + self: MixedstateMultislicePtychographicReconstruction + Self to accommodate chaining + """ + xp = self._xp + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._probe_roi_shape = probe_roi_shape + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + probe_roi_shape=self._probe_roi_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + ) + + self._intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + ) + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + self.com_x, + self.com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + ( + self._amplitudes, + self._mean_diffraction_intensity, + ) = self._normalize_diffraction_intensities( + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + ) + + # explicitly delete namespace + self._num_diffraction_patterns = self._amplitudes.shape[0] + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + del self._intensities + + self._positions_px = self._calculate_scan_positions_in_pixels( + self._scan_positions + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + if self._object is None: + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px, axis=0)) + p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( + "int" + ) + if self._object_type == "potential": + self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) + else: + if self._object_type == "potential": + self._object = xp.asarray(self._object, dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.asarray(self._object, dtype=xp.complex64) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # Vectorized Patches + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Probe Initialization + if self._probe is None or isinstance(self._probe, ComplexProbe): + if self._probe is None: + if self._vacuum_probe_intensity is not None: + self._semiangle_cutoff = np.inf + self._vacuum_probe_intensity = xp.asarray( + self._vacuum_probe_intensity, dtype=xp.float32 + ) + probe_x0, probe_y0 = get_CoM( + self._vacuum_probe_intensity, + device=self._device, + ) + self._vacuum_probe_intensity = get_shifted_ar( + self._vacuum_probe_intensity, + -probe_x0, + -probe_y0, + bilinear=True, + device=self._device, + ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) + _probe = ( + ComplexProbe( + gpts=self._region_of_interest_shape, + sampling=self.sampling, + energy=self._energy, + semiangle_cutoff=self._semiangle_cutoff, + rolloff=self._rolloff, + vacuum_probe_intensity=self._vacuum_probe_intensity, + parameters=self._polar_parameters, + device=self._device, + ) + .build() + ._array + ) + + else: + if self._probe._gpts != self._region_of_interest_shape: + raise ValueError() + if hasattr(self._probe, "_array"): + _probe = self._probe._array + else: + self._probe._xp = xp + _probe = self._probe.build()._array + + self._probe = xp.zeros( + (self._num_probes,) + tuple(self._region_of_interest_shape), + dtype=xp.complex64, + ) + sx, sy = self._region_of_interest_shape + self._probe[0] = _probe + + # Randomly shift phase of other probes + for i_probe in range(1, self._num_probes): + shift_x = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) + ) + shift_y = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) + ) + self._probe[i_probe] = ( + self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] + ) + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) + self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) + + else: + self._probe = xp.asarray(self._probe, dtype=xp.complex64) + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = None # Doesn't really make sense for mixed-state + + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # Precomputed propagator arrays + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + ) + + # overlaps + shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) + probe_intensities = xp.abs(shifted_probes) ** 2 + probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + probe_overlap = self._gaussian_filter(probe_overlap, 1.0) + + if object_fov_mask is None: + self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=2, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probe[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=2, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, + chroma_boost=chroma_boost, + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe[0] intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax2, chroma_boost=chroma_boost) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe[0] intensity") + + ax3.imshow( + asnumpy(probe_overlap), + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ] + + num_probe_positions = object_patches.shape[1] + + propagated_shape = ( + self._num_slices, + num_probe_positions, + self._num_probes, + self._region_of_interest_shape[0], + self._region_of_interest_shape[1], + ) + propagated_probes = xp.empty(propagated_shape, dtype=object_patches.dtype) + propagated_probes[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes = ( + xp.expand_dims(object_patches[s], axis=1) * propagated_probes[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes[s + 1] = self._propagate_array( + transmitted_probes, self._propagator_arrays[s] + ) + + return propagated_probes, object_patches, transmitted_probes + + def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): + """ + Ptychographic fourier projection method for GD method. + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + + Returns + -------- + exit_waves:np.ndarray + Exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + intensity_norm[intensity_norm == 0.0] = np.inf + amplitude_modification = amplitudes / intensity_norm + + fourier_modified_overlap = amplitude_modification[:, None] * fourier_exit_waves + modified_exit_wave = xp.fft.ifft2(fourier_modified_overlap) + + exit_waves = modified_exit_wave - transmitted_probes + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. + Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = transmitted_probes.copy() + + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + factor_to_be_projected = ( + projection_c * transmitted_probes + projection_y * exit_waves + ) + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + intensity_norm_projected = xp.sqrt( + xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) + ) + intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf + + amplitude_modification = amplitudes / intensity_norm_projected + fourier_projected_factor *= amplitude_modification[:, None] + + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * transmitted_probes + + projection_b * projected_factor + ) + + return exit_waves, error + + def _forward( + self, + current_object, + current_probe, + amplitudes, + exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic forward operator. + Calls _overlap_projection() and the appropriate _fourier_projection(). + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + amplitudes: np.ndarray + Normalized measured amplitudes + exit_waves: np.ndarray + previously estimated exit waves + use_projection_scheme: bool, + If True, use generalized projection update + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + ( + propagated_probes, + object_patches, + transmitted_probes, + ) = self._overlap_projection(current_object, current_probe) + + if use_projection_scheme: + exit_waves, error = self._projection_sets_fourier_projection( + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ) + + else: + exit_waves, error = self._gradient_descent_fourier_projection( + amplitudes, transmitted_probes + ) + + return propagated_probes, object_patches, transmitted_probes, exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves[:, i_probe] + ) + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] += object_update * probe_normalization + + # back-transmit + exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] + ) + ) + else: + object_update += self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] = object_update * probe_normalization + + # back-transmit + exit_waves_copy *= xp.expand_dims( + xp.conj(obj), axis=1 + ) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + use_projection_scheme: bool, + step_size: float, + normalization_min: float, + fix_probe: bool, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + use_projection_scheme: bool, + If True, use generalized projection update + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + + if use_projection_scheme: + current_object, current_probe = self._projection_sets_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ) + else: + current_object, current_probe = self._gradient_descent_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ) + + return current_object, current_probe + + def _position_correction( + self, + current_object, + current_probe, + transmitted_probes, + amplitudes, + current_positions, + positions_step_size, + constrain_position_distance, + ): + """ + Position correction using estimated intensity gradient. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe:np.ndarray + fractionally-shifted probes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + amplitudes: np.ndarray + Measured amplitudes + current_positions: np.ndarray + Current positions estimate + positions_step_size: float + Positions step size + constrain_position_distance: float + Distance to constrain position correction within original + field of view in A + + Returns + -------- + updated_positions: np.ndarray + Updated positions estimate + """ + + xp = self._xp + + # Intensity gradient + exit_waves_fft = xp.fft.fft2(transmitted_probes) + exit_waves_fft_conj = xp.conj(exit_waves_fft) + estimated_intensity = xp.abs(exit_waves_fft) ** 2 + measured_intensity = amplitudes**2 + + flat_shape = (transmitted_probes.shape[0], -1) + difference_intensity = (measured_intensity - estimated_intensity).reshape( + flat_shape + ) + + # Computing perturbed exit waves one at a time to save on memory + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + # dx + obj_rolled_patches = complex_object[ + :, + (self._vectorized_patch_indices_row + 1) % self._object_shape[0], + self._vectorized_patch_indices_col, + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + # dy + obj_rolled_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + (self._vectorized_patch_indices_col + 1) % self._object_shape[1], + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + partial_intensity_dx = 2 * xp.real( + exit_waves_dx_fft * exit_waves_fft_conj + ).reshape(flat_shape) + partial_intensity_dy = 2 * xp.real( + exit_waves_dy_fft * exit_waves_fft_conj + ).reshape(flat_shape) + + coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) + + # positions_update = xp.einsum( + # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity + # ) + + coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) + positions_update = ( + xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) + @ coefficients_matrix_T + @ difference_intensity[..., None] + ) + + if constrain_position_distance is not None: + constrain_position_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + x1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 0 + ] + y1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 1 + ] + x0 = self._positions_px_initial[:, 0] + y0 = self._positions_px_initial[:, 1] + if self._rotation_best_transpose: + x0, y0 = xp.array([y0, x0]) + x1, y1 = xp.array([y1, x1]) + + if self._rotation_best_rad is not None: + rotation_angle = self._rotation_best_rad + x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( + -rotation_angle + ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) + x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( + -rotation_angle + ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) + + outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( + x1 < (xp.min(x0) - constrain_position_distance) + ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( + y1 < (xp.min(y0) - constrain_position_distance) + ) > 0 + + positions_update[..., 0][outlier_ind] = 0 + + current_positions -= positions_step_size * positions_update[..., 0] + + return current_positions + + def _probe_center_of_mass_constraint(self, current_probe): + """ + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + probe_intensity = xp.abs(current_probe[0]) ** 2 + + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True + ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) + + return shifted_probe + + def _probe_orthogonalization_constraint(self, current_probe): + """ + Ptychographic probe-orthogonalization constraint. + Used to ensure mixed states are orthogonal to each other. + Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Orthogonalized probe estimate + """ + xp = self._xp + n_probes = self._num_probes + + # compute upper half of P* @ P + pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) + + for i in range(n_probes): + for j in range(i, n_probes): + pairwise_dot_product[i, j] = xp.sum( + current_probe[i].conj() * current_probe[j] + ) + + # compute eigenvectors (effectively cheaper way of computing V* from SVD) + _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") + current_probe = xp.tensordot(evecs.T, current_probe, axes=1) + + # sort by real-space intensity + intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) + intensities_order = xp.argsort(intensities, axis=None)[::-1] + return current_probe[intensities_order] + + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): + """ + 2D Butterworth filter + Used for low/high-pass filtering object. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qya, qxa = xp.meshgrid(qy, qx) + qra = xp.sqrt(qxa**2 + qya**2) + + env = xp.ones_like(qra) + if q_highpass: + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) + if q_lowpass: + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) + + current_object_mean = xp.mean(current_object) + current_object -= current_object_mean + current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) + current_object += current_object_mean + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_kz_regularization_constraint( + self, current_object, kz_regularization_gamma + ): + """ + Arctan regularization filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + kz_regularization_gamma: float + Slice regularization strength + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + current_object = xp.pad( + current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" + ) + + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) + + kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] + + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qz2 = qza**2 * kz_regularization_gamma**2 + qr2 = qxa**2 + qya**2 + + w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) + + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) + current_object = current_object[1:] + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_identical_slices_constraint(self, current_object): + """ + Strong regularization forcing all slices to be identical + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + object_mean = current_object.mean(0, keepdims=True) + current_object[:] = object_mean + + return current_object + + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + + def _constraints( + self, + current_object, + current_probe, + current_positions, + fix_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, + fix_positions, + global_affine_transformation, + gaussian_filter, + gaussian_filter_sigma, + butterworth_filter, + q_lowpass, + q_highpass, + butterworth_order, + kz_regularization_filter, + kz_regularization_gamma, + identical_slices, + object_positivity, + shrinkage_rad, + object_mask, + pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + orthogonalize_probe, + ): + """ + Ptychographic constraints operator. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + current_positions: np.ndarray + Current positions estimate + fix_com: bool + If True, probe CoM is fixed to the center + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool + If True, probe amplitude is constrained by top hat function + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_probe_aperture: bool + If True, probe Fourier amplitude is replaced by initial_probe_aperture + initial_probe_aperture: np.ndarray + Initial probe aperture to use in replacing probe Fourier amplitude + fix_positions: bool + If True, positions are not updated + gaussian_filter: bool + If True, applies real-space gaussian filter in A + gaussian_filter_sigma: float + Standard deviation of gaussian kernel + butterworth_filter: bool + If True, applies fourier-space butterworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter: bool + If True, applies fourier-space arctan regularization filter + kz_regularization_gamma: float + Slice regularization strength + identical_slices: bool + If True, forces all object slices to be identical + object_positivity: bool + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + pure_phase_object: bool + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True, performs TV denoising along z + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + orthogonalize_probe: bool + If True, probe will be orthogonalized + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + constrained_probe: np.ndarray + Constrained probe estimate + constrained_positions: np.ndarray + Constrained positions estimate + """ + + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + + if identical_slices: + current_object = self._object_identical_slices_constraint(current_object) + elif kz_regularization_filter: + current_object = self._object_kz_regularization_constraint( + current_object, kz_regularization_gamma + ) + elif tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + elif tv_denoise_chambolle: + current_object = self._object_denoise_tv_chambolle( + current_object, + tv_denoise_weight_chambolle, + axis=0, + pad_object=tv_denoise_pad_chambolle, + ) + + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + if fix_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # These constraints don't _really_ make sense for mixed-state + if fix_probe_aperture: + raise NotImplementedError() + elif constrain_probe_fourier_amplitude: + raise NotImplementedError() + if fit_probe_aberrations: + raise NotImplementedError() + if constrain_probe_amplitude: + raise NotImplementedError() + + if orthogonalize_probe: + current_probe = self._probe_orthogonalization_constraint(current_probe) + + if not fix_positions: + current_positions = self._positions_center_of_mass_constraint( + current_positions + ) + + if global_affine_transformation: + current_positions = self._positions_affine_transformation_constraint( + self._positions_px_initial, current_positions + ) + + return current_object, current_probe, current_positions + + def reconstruct( + self, + max_iter: int = 64, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_com: bool = True, + orthogonalize_probe: bool = True, + fix_probe_iter: int = 0, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions_iter: int = np.inf, + constrain_position_distance: float = None, + global_affine_transformation: bool = True, + gaussian_filter_sigma: float = None, + gaussian_filter_iter: int = np.inf, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + butterworth_filter_iter: int = np.inf, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + kz_regularization_filter_iter: int = np.inf, + kz_regularization_gamma: Union[float, np.ndarray] = None, + identical_slices_iter: int = 0, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + switch_object_iter: int = np.inf, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + max_iter: int, optional + Maximum number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_com: bool, optional + If True, fixes center of mass of probe + fix_probe_iter: int, optional + Number of iterations to run with a fixed probe before updating probe estimate + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions_iter: int, optional + Number of iterations to run with fixed positions before updating positions estimate + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter_iter: int, optional + Number of iterations to run using object smoothness constraint + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + butterworth_filter_iter: int, optional + Number of iterations to run using high-pass butteworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter_iter: int, optional + Number of iterations to run using kz regularization filter + kz_regularization_gamma, float, optional + kz regularization strength + identical_slices_iter: int, optional + Number of iterations to run using identical slices + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + pure_phase_object_iter: int, optional + Number of iterations where object amplitude is set to unity + tv_denoise_iter_chambolle: bool + Number of iterations with TV denoisining + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + switch_object_iter: int, optional + Iteration to switch object type between 'complex' and 'potential' or between + 'potential' and 'complex' + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + asnumpy = self._asnumpy + xp = self._xp + + # Reconstruction method + + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): + raise ValueError( + ( + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." + ) + ) + + use_projection_scheme = True + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c + step_size = None + elif ( + reconstruction_method == "DM_AP" + or reconstruction_method == "difference-map_alternating-projections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = 1 + projection_c = 1 + reconstruction_parameter + step_size = None + elif ( + reconstruction_method == "RAAR" + or reconstruction_method == "relaxed-averaged-alternating-reflections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = 1 - 2 * reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "RRR" + or reconstruction_method == "relax-reflect-reflect" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: + raise ValueError("reconstruction_parameter must be between 0-2.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "SUPERFLIP" + or reconstruction_method == "charge-flipping" + ): + use_projection_scheme = True + projection_a = 0 + projection_b = 1 + projection_c = 2 + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "GD" or reconstruction_method == "gradient-descent" + ): + use_projection_scheme = False + projection_a = None + projection_b = None + projection_c = None + reconstruction_parameter = None + else: + raise ValueError( + ( + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " + "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " + "'RRR' (or 'relax-reflect-reflect'), " + "'SUPERFLIP' (or 'charge-flipping'), " + f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." + ) + ) + + if self._verbose: + if switch_object_iter > max_iter: + first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " + else: + switch_object_type = ( + "complex" if self._object_type == "potential" else "potential" + ) + first_line = ( + f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " + f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " + ) + if max_batch_size is not None: + if use_projection_scheme: + raise ValueError( + ( + "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " + "Use reconstruction_method='GD' or set max_batch_size=None." + ) + ) + else: + print( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}, " + f"in batches of max {max_batch_size} measurements." + ) + ) + + else: + if reconstruction_parameter is not None: + if np.array(reconstruction_parameter).shape == (3,): + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." + ) + ) + else: + if step_size is not None: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}." + ) + ) + + # Batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + unshuffled_indices = np.zeros_like(shuffled_indices) + + if max_batch_size is not None: + xp.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + if reset: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probe = self._probe_initial.copy() + self._positions_px = self._positions_px_initial.copy() + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + self._exit_waves = None + self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] + self._exit_waves = None + + # main loop + for a0 in tqdmnd( + max_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if a0 == switch_object_iter: + if self._object_type == "potential": + self._object_type = "complex" + self._object = xp.exp(1j * self._object) + elif self._object_type == "complex": + self._object_type = "potential" + self._object = xp.angle(self._object) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + unshuffled_indices[shuffled_indices] = np.arange( + self._num_diffraction_patterns + ) + positions_px = self._positions_px.copy()[shuffled_indices] + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[shuffled_indices[start:end]] + + # forward operator + ( + propagated_probes, + object_patches, + self._transmitted_probes, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + self._probe, + amplitudes, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + propagated_probes, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=a0 < fix_probe_iter, + ) + + # position correction + if a0 >= fix_positions_iter: + positions_px[start:end] = self._position_correction( + self._object, + self._probe[0], + self._transmitted_probes[:, 0], + amplitudes, + self._positions_px, + positions_step_size, + constrain_position_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._positions_px = positions_px.copy()[unshuffled_indices] + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + gaussian_filter=a0 < gaussian_filter_iter + and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=a0 < butterworth_filter_iter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + kz_regularization_filter=a0 < kz_regularization_filter_iter + and kz_regularization_gamma is not None, + kz_regularization_gamma=kz_regularization_gamma[a0] + if kz_regularization_gamma is not None + and isinstance(kz_regularization_gamma, np.ndarray) + else kz_regularization_gamma, + identical_slices=a0 < identical_slices_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=a0 < pure_phase_object_iter + and self._object_type == "complex", + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + orthogonalize_probe=orthogonalize_probe, + ) + + self.error_iterations.append(error.item()) + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _visualize_last_iteration_figax( + self, + fig, + object_ax, + convergence_ax, + cbar: bool, + padding: int = 0, + **kwargs, + ): + """ + Displays last reconstructed object on a given fig/ax. + + Parameters + -------- + fig: Figure + Matplotlib figure object_ax lives in + object_ax: Axes + Matplotlib axes to plot reconstructed object in + convergence_ax: Axes, optional + Matplotlib axes to plot convergence plot in + cbar: bool, optional + If true, displays a colorbar + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + cmap = kwargs.pop("cmap", "magma") + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + im = object_ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(object_ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if convergence_ax is not None and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = self.error_iterations + + convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + padding: int, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + + """ + figsize = kwargs.pop("figsize", (8, 5)) + cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + + ax = fig.add_subplot(spec[0, 1]) + if plot_fourier_probe: + probe_array = Complex2RGB( + self.probe_fourier[0], chroma_boost=chroma_boost + ) + ax.set_title("Reconstructed Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + self.probe[0], power=2, chroma_boost=chroma_boost + ) + ax.set_title("Reconstructed probe[0] intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + ax = fig.add_subplot(spec[0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = np.array(self.error_iterations) + if plot_probe: + ax = fig.add_subplot(spec[1, :]) + else: + ax = fig.add_subplot(spec[1]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + def _visualize_all_iterations( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + iterations_grid: Tuple[int, int], + padding: int, + **kwargs, + ): + """ + Displays all reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + asnumpy = self._asnumpy + + if not hasattr(self, "object_iterations"): + raise ValueError( + ( + "Object and probe iterations were not saved during reconstruction. " + "Please re-run using store_iterations=True." + ) + ) + + if iterations_grid == "auto": + num_iter = len(self.error_iterations) + + if num_iter == 1: + return self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + elif plot_probe or plot_fourier_probe: + iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) + else: + iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + else: + if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: + raise ValueError() + + auto_figsize = ( + (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) + if plot_convergence + else (3 * iterations_grid[1], 3 * iterations_grid[0]) + ) + figsize = kwargs.pop("figsize", auto_figsize) + cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + + errors = np.array(self.error_iterations) + + objects = [] + object_type = [] + + for obj in self.object_iterations: + if np.iscomplexobj(obj): + obj = np.angle(obj) + object_type.append("phase") + else: + object_type.append("potential") + objects.append( + self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) + ) + + if plot_probe or plot_fourier_probe: + total_grids = (np.prod(iterations_grid) / 2).astype("int") + probes = self.probe_iterations + else: + total_grids = np.prod(iterations_grid) + max_iter = len(objects) - 1 + grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) + + extent = [ + 0, + self.sampling[1] * objects[0].shape[1], + self.sampling[0] * objects[0].shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=2) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + grid = ImageGrid( + fig, + spec[0], + nrows_ncols=(1, iterations_grid[1]) + if (plot_probe or plot_fourier_probe) + else iterations_grid, + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + im = ax.imshow( + objects[grid_range[n]], + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if cbar: + grid.cbar_axes[n].colorbar(im) + + if plot_probe or plot_fourier_probe: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + grid = ImageGrid( + fig, + spec[1], + nrows_ncols=(1, iterations_grid[1]), + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + if plot_fourier_probe: + probe_array = Complex2RGB( + asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]][0] + ) + ), + chroma_boost=chroma_boost, + ) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + probes[grid_range[n]][0], + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + add_colorbar_arg( + grid.cbar_axes[n], + chroma_boost=chroma_boost, + ) + + if plot_convergence: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + if plot_probe: + ax2 = fig.add_subplot(spec[2]) + else: + ax2 = fig.add_subplot(spec[1]) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax2.set_ylabel("NMSE") + ax2.set_xlabel("Iteration number") + ax2.yaxis.tick_right() + + spec.tight_layout(fig) + + def visualize( + self, + fig=None, + iterations_grid: Tuple[int, int] = None, + plot_convergence: bool = True, + plot_probe: bool = True, + plot_fourier_probe: bool = False, + cbar: bool = True, + padding: int = 0, + **kwargs, + ): + """ + Displays reconstructed object and probe. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + + if iterations_grid is None: + self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + else: + self._visualize_all_iterations( + fig=fig, + plot_convergence=plot_convergence, + iterations_grid=iterations_grid, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + return self + + def show_fourier_probe( + self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs + ): + """ + Plot probe in fourier space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is probe reciprocal sampling + """ + asnumpy = self._asnumpy + + if probe is None: + probe = list(self.probe_fourier) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe] + + if pixelsize is None: + pixelsize = self._reciprocal_sampling[1] + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + chroma_boost = kwargs.pop("chroma_boost", 2) + + show_complex( + probe if len(probe) > 1 else probe[0], + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + chroma_boost=chroma_boost, + **kwargs, + ) + + def show_transmitted_probe( + self, + plot_fourier_probe: bool = False, + **kwargs, + ): + """ + Plots the min, max, and mean transmitted probe after propagation and transmission. + + Parameters + ---------- + plot_fourier_probe: boolean, optional + If True, the transmitted probes are also plotted in Fourier space + kwargs: + Passed to show_complex + """ + + xp = self._xp + asnumpy = self._asnumpy + + transmitted_probe_intensities = xp.sum( + xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1) + ) + min_intensity_transmitted = self._transmitted_probes[ + xp.argmin(transmitted_probe_intensities), 0 + ] + max_intensity_transmitted = self._transmitted_probes[ + xp.argmax(transmitted_probe_intensities), 0 + ] + mean_transmitted = self._transmitted_probes[:, 0].mean(0) + probes = [ + asnumpy(self._return_centered_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + title = [ + "Mean Transmitted Probe", + "Min Intensity Transmitted Probe", + "Max Intensity Transmitted Probe", + ] + + if plot_fourier_probe: + bottom_row = [ + asnumpy(self._return_fourier_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + probes = [probes, bottom_row] + + title += [ + "Mean Transmitted Fourier Probe", + "Min Intensity Transmitted Fourier Probe", + "Max Intensity Transmitted Fourier Probe", + ] + + title = kwargs.get("title", title) + show_complex( + probes, + title=title, + **kwargs, + ) + + def show_slices( + self, + ms_object=None, + cbar: bool = True, + common_color_scale: bool = True, + padding: int = 0, + num_cols: int = 3, + **kwargs, + ): + """ + Displays reconstructed slices of object + + Parameters + -------- + ms_object: nd.array, optional + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + padding: int, optional + Padding to leave uncropped + num_cols: int, optional + Number of GridSpec columns + """ + + if ms_object is None: + ms_object = self._object + + rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + rotated_shape = rotated_object.shape + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + + extent = [ + 0, + self.sampling[1] * rotated_shape[2], + self.sampling[0] * rotated_shape[1], + 0, + ] + + num_rows = np.ceil(self._num_slices / num_cols).astype("int") + wspace = 0.35 if cbar else 0.15 + + axsize = kwargs.pop("axsize", (3, 3)) + cmap = kwargs.pop("cmap", "magma") + vmin = np.min(rotated_object) if common_color_scale else None + vmax = np.max(rotated_object) if common_color_scale else None + vmin = kwargs.pop("vmin", vmin) + vmax = kwargs.pop("vmax", vmax) + + spec = GridSpec( + ncols=num_cols, + nrows=num_rows, + hspace=0.15, + wspace=wspace, + ) + + figsize = (axsize[0] * num_cols, axsize[1] * num_rows) + fig = plt.figure(figsize=figsize) + + for flat_index, obj_slice in enumerate(rotated_object): + row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) + ax = fig.add_subplot(spec[row_index, col_index]) + im = ax.imshow( + obj_slice, + cmap=cmap, + vmin=vmin, + vmax=vmax, + extent=extent, + **kwargs, + ) + + ax.set_title(f"Slice index: {flat_index}") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if row_index < num_rows - 1: + ax.set_xticks([]) + else: + ax.set_xlabel("y [A]") + + if col_index > 0: + ax.set_yticks([]) + else: + ax.set_ylabel("x [A]") + + spec.tight_layout(fig) + + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + figsize = kwargs.pop("figsize", (6, 6)) + if not plot_line_profile: + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], + 0, + ] + fig, ax = plt.subplots(2, 1, figsize=figsize) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + + def tune_num_slices_and_thicknesses( + self, + num_slices_guess=None, + thicknesses_guess=None, + num_slices_step_size=1, + thicknesses_step_size=20, + num_slices_values=3, + num_thicknesses_values=3, + update_defocus=False, + max_iter=5, + plot_reconstructions=True, + plot_convergence=True, + return_values=False, + **kwargs, + ): + """ + Run reconstructions over a parameters space of number of slices + and slice thicknesses. Should be run after the preprocess step. + + Parameters + ---------- + num_slices_guess: float, optional + initial starting guess for number of slices, rounds to nearest integer + if None, uses current initialized values + thicknesses_guess: float (A), optional + initial starting guess for thicknesses of slices assuming same + thickness for each slice + if None, uses current initialized values + num_slices_step_size: float, optional + size of change of number of slices for each step in parameter space + thicknesses_step_size: float (A), optional + size of change of slice thicknesses for each step in parameter space + num_slices_values: int, optional + number of number of slice values to test, must be >= 1 + num_thicknesses_values: int,optional + number of thicknesses values to test, must be >= 1 + update_defocus: bool, optional + if True, updates defocus based on estimated total thickness + max_iter: int, optional + number of iterations to run in ptychographic reconstruction + plot_reconstructions: bool, optional + if True, plot phase of reconstructed objects + plot_convergence: bool, optional + if True, plots error for each iteration for each reconstruction + return_values: bool, optional + if True, returns objects, convergence + + Returns + ------- + objects: list + reconstructed objects + convergence: np.ndarray + array of convergence values from reconstructions + """ + + # calculate number of slices and thicknesses values to test + if num_slices_guess is None: + num_slices_guess = self._num_slices + if thicknesses_guess is None: + thicknesses_guess = np.mean(self._slice_thicknesses) + + if num_slices_values == 1: + num_slices_step_size = 0 + + if num_thicknesses_values == 1: + thicknesses_step_size = 0 + + num_slices = np.linspace( + num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_values, + ) + + thicknesses = np.linspace( + thicknesses_guess + - thicknesses_step_size * (num_thicknesses_values - 1) / 2, + thicknesses_guess + + thicknesses_step_size * (num_thicknesses_values - 1) / 2, + num_thicknesses_values, + ) + + if return_values: + convergence = [] + objects = [] + + # current initialized values + current_verbose = self._verbose + current_num_slices = self._num_slices + current_thicknesses = self._slice_thicknesses + current_rotation_deg = self._rotation_best_rad * 180 / np.pi + current_transpose = self._rotation_best_transpose + current_defocus = -self._polar_parameters["C10"] + + # Gridspec to plot on + if plot_reconstructions: + if plot_convergence: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values * 2, + height_ratios=[1, 1 / 4] * num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) + ) + else: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) + ) + + fig = plt.figure(figsize=figsize) + + progress_bar = kwargs.pop("progress_bar", False) + # run loop and plot along the way + self._verbose = False + for flat_index, (slices, thickness) in enumerate( + tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") + ): + slices = int(slices) + self._num_slices = slices + self._slice_thicknesses = np.tile(thickness, slices - 1) + self._probe = None + self._object = None + if update_defocus: + defocus = current_defocus + slices / 2 * thickness + self._polar_parameters["C10"] = -defocus + + self.preprocess( + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + ) + self.reconstruct( + reset=True, + store_iterations=True if plot_convergence else False, + max_iter=max_iter, + progress_bar=progress_bar, + **kwargs, + ) + + if plot_reconstructions: + row_index, col_index = np.unravel_index( + flat_index, (num_slices_values, num_thicknesses_values) + ) + + if plot_convergence: + object_ax = fig.add_subplot(spec[row_index * 2, col_index]) + convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=convergence_ax, + cbar=True, + ) + convergence_ax.yaxis.tick_right() + else: + object_ax = fig.add_subplot(spec[row_index, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=None, + cbar=True, + ) + + object_ax.set_title( + f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" + ) + object_ax.set_xticks([]) + object_ax.set_yticks([]) + + if return_values: + objects.append(self.object) + convergence.append(self.error_iterations.copy()) + + # initialize back to pre-tuning values + self._probe = None + self._object = None + self._num_slices = current_num_slices + self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) + self._polar_parameters["C10"] = -current_defocus + self.preprocess( + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + ) + self._verbose = current_verbose + + if plot_reconstructions: + spec.tight_layout(fig) + + if return_values: + return objects, convergence + + def _return_object_fft( + self, + obj=None, + ): + """ + Returns obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + """ + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + + obj = asnumpy(obj) + if np.iscomplexobj(obj): + obj = np.angle(obj) + + obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index ceae66cd8..2e9fbd076 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -204,6 +204,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -261,6 +262,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -346,9 +349,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -429,6 +430,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) _probe = ( ComplexProbe( @@ -505,19 +510,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -540,23 +539,19 @@ def preprocess( axs[i].imshow( complex_probe_rgb[i], extent=probe_extent, - **kwargs, ) axs[i].set_ylabel("x [A]") axs[i].set_xlabel("y [A]") - axs[i].set_title(f"Initial Probe[{i}]") + axs[i].set_title(f"Initial probe[{i}] intensity") divider = make_axes_locatable(axs[i]) cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax, chroma_boost=chroma_boost) axs[-1].imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) axs[-1].scatter( self.positions[:, 1], @@ -568,7 +563,7 @@ def preprocess( axs[-1].set_xlabel("y [A]") axs[-1].set_xlim((extent[0], extent[1])) axs[-1].set_ylim((extent[2], extent[3])) - axs[-1].set_title("Object Field of View") + axs[-1].set_title("Object field of view") fig.tight_layout() @@ -1125,6 +1120,9 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, orthogonalize_probe, object_positivity, shrinkage_rad, @@ -1183,6 +1181,12 @@ def _constraints( Butterworth filter order. Smaller gives a smoother filter orthogonalize_probe: bool If True, probe will be orthogonalized + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1213,6 +1217,11 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1281,6 +1290,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, global_affine_transformation: bool = True, + constrain_position_distance: float = None, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, fit_probe_aberrations_iter: int = 0, @@ -1290,6 +1300,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1353,6 +1366,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1373,6 +1388,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1575,6 +1596,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -1667,6 +1690,7 @@ def reconstruct( amplitudes, self._positions_px, positions_step_size, + constrain_position_distance, ) error += batch_error @@ -1707,6 +1731,9 @@ def reconstruct( q_highpass=q_highpass, butterworth_order=butterworth_order, orthogonalize_probe=orthogonalize_probe, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -1821,8 +1848,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -1915,29 +1945,31 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier[0], hue_start=hue_start, invert=invert + self.probe_fourier[0], + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe[0], hue_start=hue_start, invert=invert + self.probe[0], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe[0]") + ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1970,10 +2002,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2045,8 +2077,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2149,29 +2184,30 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]][0], hue_start=hue_start, invert=invert + probes[grid_range[n]][0], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2183,7 +2219,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2280,11 +2316,14 @@ def show_fourier_probe( if pixelunits is None: pixelunits = r"$\AA^{-1}$" + chroma_boost = kwargs.pop("chroma_boost", 2) + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost=chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index aee383675..4515590fe 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex @@ -29,6 +30,7 @@ spatial_frequencies, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate warnings.simplefilter(action="always", category=UserWarning) @@ -78,6 +80,10 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + theta_x: float + x tilt of propagator (in angles) + theta_y: float + y tilt of propagator (in angles) object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') @@ -109,6 +115,8 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, object_type: str = "complex", verbose: bool = True, device: str = "cpu", @@ -189,6 +197,8 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y def _precompute_propagator_arrays( self, @@ -196,6 +206,8 @@ def _precompute_propagator_arrays( sampling: Tuple[float, float], energy: float, slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, ): """ Precomputes propagator arrays complex wave-function will be convolved by, @@ -211,6 +223,10 @@ def _precompute_propagator_arrays( The electron energy of the wave functions in eV slice_thicknesses: Sequence[float] Array of slice thicknesses in A + theta_x: float + x tilt of propagator (in angles) + theta_y: float + y tilt of propagator (in angles) Returns ------- @@ -230,6 +246,10 @@ def _precompute_propagator_arrays( propagators = xp.empty( (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 ) + + theta_x = np.deg2rad(theta_x) + theta_y = np.deg2rad(theta_y) + for i, dz in enumerate(slice_thicknesses): propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) @@ -237,6 +257,12 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) ) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) return propagators @@ -279,6 +305,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -336,6 +363,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -421,9 +450,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -503,6 +530,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -559,6 +590,8 @@ def preprocess( self.sampling, self._energy, self._slice_thicknesses, + self._theta_x, + self._theta_y, ) # overlaps @@ -575,19 +608,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -599,10 +626,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -624,38 +649,34 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[:, 1], @@ -667,7 +688,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1449,6 +1470,111 @@ def _object_identical_slices_constraint(self, current_object): return current_object + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1481,9 +1607,12 @@ def _constraints( shrinkage_rad, object_mask, pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, tv_denoise, - tv_denoise_weight, - tv_denoise_pad, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1548,12 +1677,19 @@ def _constraints( If not None, used to calculate additional shrinkage using masked-mean of object pure_phase_object: bool If True, object amplitude is set to unity - tv_denoise: bool + tv_denoise_chambolle: bool If True, performs TV denoising along z - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1585,13 +1721,17 @@ def _constraints( current_object, kz_regularization_gamma ) elif tv_denoise: - if self._object_type == "complex": - raise NotImplementedError() + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + elif tv_denoise_chambolle: current_object = self._object_denoise_tv_chambolle( current_object, - tv_denoise_weight, + tv_denoise_weight_chambolle, axis=0, - pad_object=tv_denoise_pad, + pad_object=tv_denoise_pad_chambolle, ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1690,9 +1830,12 @@ def reconstruct( shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, tv_denoise_iter=np.inf, - tv_denoise_weight=None, - tv_denoise_pad=True, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, @@ -1751,6 +1894,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1785,12 +1930,19 @@ def reconstruct( If true, the potential mean outside the FOV is forced to zero at each iteration pure_phase_object_iter: int, optional Number of iterations where object amplitude is set to unity - tv_denoise_iter: bool + tv_denoise_iter_chambolle: bool Number of iterations with TV denoisining - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising switch_object_iter: int, optional Iteration to switch object type between 'complex' and 'potential' or between 'potential' and 'complex' @@ -1987,6 +2139,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2123,7 +2277,7 @@ def reconstruct( and kz_regularization_gamma is not None, kz_regularization_gamma=kz_regularization_gamma[a0] if kz_regularization_gamma is not None - and type(kz_regularization_gamma) == np.ndarray + and isinstance(kz_regularization_gamma, np.ndarray) else kz_regularization_gamma, identical_slices=a0 < identical_slices_iter, object_positivity=object_positivity, @@ -2133,9 +2287,13 @@ def reconstruct( else None, pure_phase_object=a0 < pure_phase_object_iter and self._object_type == "complex", - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_pad=tv_denoise_pad, + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) @@ -2250,8 +2408,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -2347,29 +2508,29 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, power=2, chroma_boost=chroma_boost ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2402,10 +2563,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2477,8 +2638,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2583,29 +2747,28 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], power=2, chroma_boost=chroma_boost ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2617,7 +2780,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2841,6 +3004,143 @@ def show_slices( spec.tight_layout(fig) + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + figsize = kwargs.pop("figsize", (6, 6)) + if not plot_line_profile: + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], + 0, + ] + + fig, ax = plt.subplots(2, 1, figsize=figsize) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + def tune_num_slices_and_thicknesses( self, num_slices_guess=None, diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index b09d18ca7..32b0f6fd4 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize import show @@ -430,6 +431,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -474,6 +476,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -591,9 +595,7 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, + intensities, com_fitted_x, com_fitted_y, crop_patterns ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -684,6 +686,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -807,19 +813,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -831,10 +831,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -856,38 +854,37 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -899,7 +896,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1679,6 +1676,111 @@ def _divergence_free_constraint(self, vector_field): return vector_field + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1710,6 +1812,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1771,6 +1876,15 @@ def _constraints( If True, forces object to be positive shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1822,6 +1936,31 @@ def _constraints( butterworth_order, ) + elif tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[1] = self._object_denoise_tv_pylops( + current_object[1], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[2] = self._object_denoise_tv_pylops( + current_object[2], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[3] = self._object_denoise_tv_pylops( + current_object[3], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object[0] = self._object_shrinkage_constraint( current_object[0], @@ -1913,6 +2052,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1998,6 +2140,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_tilt_updates: bool + if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -2171,12 +2322,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy() + if hasattr(self, "_tf"): + del self._tf if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None - elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2477,6 +2629,10 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) # Normalize Error Over Tilts @@ -2530,6 +2686,9 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) @@ -2929,7 +3088,7 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[-1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() spec.tight_layout(fig) @@ -3133,7 +3292,7 @@ def show_object_fft( vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 1f6be1c38..66cf46487 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize import show @@ -55,8 +56,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): The electron energy of the wave functions in eV num_slices: int Number of slices to use in the forward model - tilt_angles_deg: Sequence[float] - List of tilt angles in degrees, + tilt_orientation_matrices: Sequence[np.ndarray] + List of orientation matrices for each tilt semiangle_cutoff: float, optional Semiangle cutoff for the initial probe guess in mrad semiangle_cutoff_pixels: float, optional @@ -94,13 +95,14 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): """ # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_angles_deg") + _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") + _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) def __init__( self, energy: float, num_slices: int, - tilt_angles_deg: Sequence[float], + tilt_orientation_matrices: Sequence[np.ndarray], datacube: Sequence[DataCube] = None, semiangle_cutoff: float = None, semiangle_cutoff_pixels: float = None, @@ -122,22 +124,29 @@ def __init__( if device == "cpu": self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter, rotate, zoom + from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform from scipy.special import erf self._erf = erf elif device == "gpu": self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom + from cupyx.scipy.ndimage import ( + affine_transform, + gaussian_filter, + rotate, + zoom, + ) self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform from cupyx.scipy.special import erf self._erf = erf @@ -156,7 +165,7 @@ def __init__( polar_parameters.update(kwargs) self._set_polar_parameters(polar_parameters) - num_tilts = len(tilt_angles_deg) + num_tilts = len(tilt_orientation_matrices) if initial_scan_positions is None: initial_scan_positions = [None] * num_tilts @@ -185,7 +194,7 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices - self._tilt_angles_deg = tuple(tilt_angles_deg) + self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) self._num_tilts = num_tilts def _precompute_propagator_arrays( @@ -323,6 +332,29 @@ def _expand_sliced_object(self, array: np.ndarray, output_z): normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] + def _rotate_zxy_volume( + self, + volume_array, + rot_matrix, + ): + """ """ + + xp = self._xp + affine_transform = self._affine_transform + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + volume = volume_array.copy() + volume_shape = xp.asarray(volume.shape) + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) + + in_center = (volume_shape - 1) / 2 + out_center = tf @ in_center + offset = in_center - out_center + + volume = affine_transform(volume, tf, offset=offset, order=3) + + return volume + def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, @@ -340,6 +372,7 @@ def preprocess( force_reciprocal_sampling: float = None, progress_bar: bool = True, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -384,6 +417,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -500,9 +535,7 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, + intensities, com_fitted_x, com_fitted_y, crop_patterns ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -593,6 +626,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -663,15 +700,14 @@ def preprocess( # overlaps if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object) + old_rot_matrix = np.eye(3) # identity for tilt_index in np.arange(self._num_tilts): - current_angle_deg = self._tilt_angles_deg[tilt_index] - probe_overlap_3D = self._rotate( + rot_matrix = self._tilt_orientation_matrices[tilt_index] + + probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, - current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, + rot_matrix @ old_rot_matrix.T, ) self._positions_px = self._positions_px_all[ @@ -691,14 +727,12 @@ def preprocess( ) probe_overlap_3D += probe_overlap[None] + old_rot_matrix = rot_matrix - probe_overlap_3D = self._rotate( - probe_overlap_3D, - -current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, - ) + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + old_rot_matrix.T, + ) probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) self._object_fov_mask = asnumpy( @@ -719,19 +753,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -743,10 +771,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -768,38 +794,37 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -811,7 +836,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1527,6 +1552,111 @@ def _object_butterworth_constraint( current_object += current_object_mean return xp.real(current_object) + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1555,6 +1685,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1611,6 +1744,13 @@ def _constraints( Phase shift in radians to be subtracted from the potential at each iteration object_mask: np.ndarray (boolean) If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1634,6 +1774,12 @@ def _constraints( q_highpass, butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( @@ -1723,6 +1869,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1806,6 +1955,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_tilt_updates: bool + if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -1981,12 +2139,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy() + if hasattr(self, "_tf"): + del self._tf if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None - elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2018,17 +2177,17 @@ def reconstruct( tilt_indices = np.arange(self._num_tilts) np.random.shuffle(tilt_indices) + old_rot_matrix = np.eye(3) # identity + for tilt_index in tilt_indices: self._active_tilt_index = tilt_index tilt_error = 0.0 - self._object = self._rotate( + rot_matrix = self._tilt_orientation_matrices[self._active_tilt_index] + self._object = self._rotate_zxy_volume( self._object, - self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, + rot_matrix @ old_rot_matrix.T, ) object_sliced = self._project_sliced_object( @@ -2132,23 +2291,13 @@ def reconstruct( ) if collective_tilt_updates: - collective_object += self._rotate( - object_update, - -self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, + collective_object += self._rotate_zxy_volume( + object_update, rot_matrix.T ) else: self._object += object_update - self._object = self._rotate( - self._object, - -self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, - ) + old_rot_matrix = rot_matrix # Normalize Error tilt_error /= ( @@ -2203,8 +2352,14 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) + self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T) + # Normalize Error Over Tilts error /= self._num_tilts @@ -2251,6 +2406,9 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) @@ -2431,8 +2589,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) asnumpy = self._asnumpy @@ -2534,16 +2695,19 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2556,7 +2720,10 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg( + ax_cb, + chroma_boost=chroma_boost, + ) else: ax = fig.add_subplot(spec[0]) im = ax.imshow( @@ -2585,10 +2752,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2672,8 +2839,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2788,29 +2958,30 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2822,7 +2993,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -3001,7 +3172,7 @@ def show_object_fft( vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 7c5896b6a..74688fa0b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -8,13 +8,18 @@ import matplotlib.pyplot as plt import numpy as np -from emdfile import Custom, tqdmnd +from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec -from py4DSTEM import DataCube +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable +from py4DSTEM import Calibration, DataCube +from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.utils import AffineTransform from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom +from py4DSTEM.visualize import show from scipy.linalg import polar +from scipy.optimize import minimize from scipy.special import comb try: @@ -24,6 +29,23 @@ warnings.simplefilter(action="always", category=UserWarning) +_aberration_names = { + (1, 0): "C1 ", + (1, 2): "stig ", + (2, 1): "coma ", + (2, 3): "trefoil ", + (3, 0): "C3 ", + (3, 2): "stig2 ", + (3, 4): "quadfoil ", + (4, 1): "coma2 ", + (4, 3): "trefoil2 ", + (4, 5): "pentafoil ", + (5, 0): "C5 ", + (5, 2): "stig3 ", + (5, 4): "quadfoil2 ", + (5, 6): "hexafoil ", +} + class ParallaxReconstruction(PhaseReconstruction): """ @@ -35,9 +57,6 @@ class ParallaxReconstruction(PhaseReconstruction): Input 4D diffraction pattern intensities energy: float The electron energy of the wave functions in eV - dp_mean: ndarray, optional - Mean diffraction pattern - If None, get_dp_mean() is used verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -73,6 +92,8 @@ def __init__( else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_save_defaults() + # Data self._datacube = datacube @@ -86,9 +107,78 @@ def __init__( def to_h5(self, group): """ Wraps datasets and metadata to write in emdfile classes, - notably ... + notably the (subpixel-)aligned BF. """ - raise NotImplementedError() + # instantiation metadata + self.metadata = Metadata( + name="instantiation_metadata", + data={ + "energy": self._energy, + "verbose": self._verbose, + "device": self._device, + "object_padding_px": self._object_padding_px, + "name": self.name, + }, + ) + + # preprocessing metadata + self.metadata = Metadata( + name="preprocess_metadata", + data={ + "scan_sampling": self._scan_sampling, + "wavelength": self._wavelength, + }, + ) + + # reconstruction metadata + recon_metadata = {"reconstruction_error": float(self._recon_error)} + + if hasattr(self, "aberration_C1"): + recon_metadata |= { + "aberration_rotation_QR": self.rotation_Q_to_R_rads, + "aberration_transpose": self.transpose_detected, + "aberration_C1": self.aberration_C1, + "aberration_A1x": self.aberration_A1x, + "aberration_A1y": self.aberration_A1y, + } + + if hasattr(self, "_kde_upsample_factor"): + recon_metadata |= { + "kde_upsample_factor": self._kde_upsample_factor, + } + self._subpixel_aligned_BF_emd = Array( + name="subpixel_aligned_BF", + data=self._asnumpy(self._recon_BF_subpixel_aligned), + ) + + if hasattr(self, "aberration_dict"): + self.metadata = Metadata( + name="aberrations_metadata", + data={ + v["aberration name"]: v["value [Ang]"] + for k, v in self.aberration_dict.items() + }, + ) + + self.metadata = Metadata( + name="reconstruction_metadata", + data=recon_metadata, + ) + + self._aligned_BF_emd = Array( + name="aligned_BF", + data=self._asnumpy(self._recon_BF), + ) + + # datacube + if self._save_datacube: + self.metadata = self._datacube.calibration + Custom.to_h5(self, group) + else: + dc = self._datacube + self._datacube = None + Custom.to_h5(self, group) + self._datacube = dc @classmethod def _get_constructor_args(cls, group): @@ -96,14 +186,68 @@ def _get_constructor_args(cls, group): Returns a dictionary of arguments/values to pass to the class' __init__ function """ - raise NotImplementedError() + # Get data + dict_data = cls._get_emd_attr_data(cls, group) + + # Get metadata dictionaries + instance_md = _read_metadata(group, "instantiation_metadata") + + # Fix calibrations bug + if "_datacube" in dict_data: + calibrations_dict = _read_metadata(group, "calibration")._params + cal = Calibration() + cal._params.update(calibrations_dict) + dc = dict_data["_datacube"] + dc.calibration = cal + else: + dc = None + + # Populate args and return + kwargs = { + "datacube": dc, + "energy": instance_md["energy"], + "object_padding_px": instance_md["object_padding_px"], + "name": instance_md["name"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility + } + + return kwargs def _populate_instance(self, group): """ Sets post-initialization properties, notably some preprocessing meta optional; during read, this method is run after object instantiation. """ - raise NotImplementedError() + + xp = self._xp + + # Preprocess metadata + preprocess_md = _read_metadata(group, "preprocess_metadata") + self._scan_sampling = preprocess_md["scan_sampling"] + self._wavelength = preprocess_md["wavelength"] + + # Reconstruction metadata + reconstruction_md = _read_metadata(group, "reconstruction_metadata") + self._recon_error = reconstruction_md["reconstruction_error"] + + # Data + dict_data = Custom._get_emd_attr_data(Custom, group) + + if "aberration_C1" in reconstruction_md.keys: + self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] + self.transpose_detected = reconstruction_md["aberration_transpose"] + self.aberration_C1 = reconstruction_md["aberration_C1"] + self.aberration_A1x = reconstruction_md["aberration_A1x"] + self.aberration_A1y = reconstruction_md["aberration_A1y"] + + if "kde_upsample_factor" in reconstruction_md.keys: + self._kde_upsample_factor = reconstruction_md["kde_upsample_factor"] + self._recon_BF_subpixel_aligned = xp.asarray( + dict_data["_subpixel_aligned_BF_emd"].data, dtype=xp.float32 + ) + + self._recon_BF = xp.asarray(dict_data["_aligned_BF_emd"].data, dtype=xp.float32) def preprocess( self, @@ -111,6 +255,7 @@ def preprocess( threshold_intensity: float = 0.8, normalize_images: bool = True, normalize_order=0, + descan_correct: bool = True, defocus_guess: float = None, rotation_guess: float = None, plot_average_bf: bool = True, @@ -133,6 +278,8 @@ def preprocess( defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus + descan_correct: float, optional + If True, aligns bright field stack based on measured descan rotation_guess: float, optional Initial guess of defocus value in degrees If None, first iteration assumed to be 0 @@ -171,7 +318,10 @@ def preprocess( self._datacube, require_calibrations=True, ) - self._intensities = xp.asarray(self._intensities, dtype=xp.float32) + + self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) + self._scan_shape = np.array(self._intensities.shape[:2]) + # make sure mean diffraction pattern is shaped correctly if (self._dp_mean.shape[0] != self._intensities.shape[2]) or ( self._dp_mean.shape[1] != self._intensities.shape[3] @@ -180,6 +330,45 @@ def preprocess( "dp_mean must match the datacube shape. Try setting dp_mean = None." ) + # descan correct + if descan_correct: + ( + _, + _, + com_fitted_x, + com_fitted_y, + _, + _, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=None, + fit_function="plane", + com_shifts=None, + com_measured=None, + ) + + com_fitted_x = asnumpy(com_fitted_x) + com_fitted_y = asnumpy(com_fitted_y) + intensities = asnumpy(self._intensities) + intensities_shifted = np.zeros_like(intensities) + + center_x, center_y = self._region_of_interest_shape / 2 + + for rx in range(intensities_shifted.shape[0]): + for ry in range(intensities_shifted.shape[1]): + intensity_shifted = get_shifted_ar( + intensities[rx, ry], + -com_fitted_x[rx, ry] + center_x, + -com_fitted_y[rx, ry] + center_y, + bilinear=True, + device="cpu", + ) + + intensities_shifted[rx, ry] = intensity_shifted + + self._intensities = xp.asarray(intensities_shifted, xp.float32) + self._dp_mean = self._intensities.mean((0, 1)) + # select virtual detector pixels self._dp_mask = self._dp_mean >= (xp.max(self._dp_mean) * threshold_intensity) self._num_bf_images = int(xp.count_nonzero(self._dp_mask)) @@ -187,14 +376,16 @@ def preprocess( # diffraction space coordinates self._xy_inds = np.argwhere(self._dp_mask) - self._kxy = (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) * xp.array( - self._reciprocal_sampling - )[None] + self._kxy = xp.asarray( + (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) + * xp.array(self._reciprocal_sampling)[None], + dtype=xp.float32, + ) self._probe_angles = self._kxy * self._wavelength self._kr = xp.sqrt(xp.sum(self._kxy**2, axis=1)) # Window function - x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1)[1:] + x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1, dtype=xp.float32)[1:] x -= (x[1] - x[0]) / 2 wx = ( xp.sin( @@ -205,7 +396,7 @@ def preprocess( ) ** 2 ) - y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1)[1:] + y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1, dtype=xp.float32)[1:] y -= (y[1] - y[0]) / 2 wy = ( xp.sin( @@ -222,7 +413,8 @@ def preprocess( ( self._grid_scan_shape[0] + self._object_padding_px[0], self._grid_scan_shape[1] + self._object_padding_px[1], - ) + ), + dtype=xp.float32, ) self._window_pad[ self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -245,7 +437,8 @@ def preprocess( self._grid_scan_shape[1] + self._object_padding_px[1], ) if normalize_images: - self._stack_BF = xp.ones(stack_shape) + self._stack_BF = xp.ones(stack_shape, dtype=xp.float32) + self._stack_BF_no_window = xp.ones(stack_shape, xp.float32) if normalize_order == 0: all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] @@ -259,13 +452,21 @@ def preprocess( self._window_inv[None] + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + elif normalize_order == 1: - x = xp.linspace(-0.5, 0.5, all_bfs.shape[1]) - y = xp.linspace(-0.5, 0.5, all_bfs.shape[2]) + x = xp.linspace(-0.5, 0.5, all_bfs.shape[1], xp.float32) + y = xp.linspace(-0.5, 0.5, all_bfs.shape[2], xp.float32) ya, xa = xp.meshgrid(y, x) basis = np.vstack( ( - xp.ones(xa.size), + xp.ones_like(xa), xa.ravel(), ya.ravel(), ) @@ -285,9 +486,18 @@ def preprocess( basis @ coefs[0], all_bfs.shape[1:3] ) + self._stack_BF_no_window[ + a0, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs[a0] / xp.reshape(basis @ coefs[0], all_bfs.shape[1:3]) + else: all_means = xp.mean(all_bfs, axis=(1, 2)) self._stack_BF = xp.full(stack_shape, all_means[:, None, None]) + self._stack_BF_no_window = xp.full(stack_shape, all_means[:, None, None]) self._stack_BF[ :, self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -299,9 +509,21 @@ def preprocess( + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + # Fourier space operators for image shifts qx = xp.fft.fftfreq(self._stack_BF.shape[1], d=1) + qx = xp.asarray(qx, dtype=xp.float32) + qy = xp.fft.fftfreq(self._stack_BF.shape[2], d=1) + qy = xp.asarray(qy, dtype=xp.float32) + qxa, qya = xp.meshgrid(qx, qy, indexing="ij") self._qx_shift = -2j * xp.pi * qxa self._qy_shift = -2j * xp.pi * qya @@ -336,7 +558,7 @@ def preprocess( del Gs else: - self._xy_shifts = xp.zeros((self._num_bf_images, 2)) + self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) self._stack_mean = xp.mean(self._stack_BF) self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images @@ -506,8 +728,6 @@ def tune_angle_and_defocus( convergence.append(asnumpy(self._recon_error[0])) if plot_convergence: - from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable - fig, ax = plt.subplots() ax.set_title("convergence") im = ax.imshow( @@ -533,9 +753,9 @@ def tune_angle_and_defocus( divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) + fig.colorbar(im, cax=cax) - plt.tight_layout() + fig.tight_layout() if return_values: convergence = np.array(convergence).reshape( @@ -548,7 +768,7 @@ def reconstruct( max_alignment_bin: int = None, min_alignment_bin: int = 1, max_iter_at_min_bin: int = 2, - upsample_factor: int = 8, + cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), regularize_shifts: bool = True, running_average: bool = True, @@ -570,7 +790,7 @@ def reconstruct( Minimum bin size for bright field alignment max_iter_at_min_bin: int, optional Number of iterations to run at the smallest bin size - upsample_factor: int, optional + cross_correlation_upsample_factor: int, optional DFT upsample factor for subpixel alignment regularizer_matrix_size: Tuple[int,int], optional Bernstein basis degree used for regularizing shifts @@ -623,7 +843,8 @@ def reconstruct( ( self._num_bf_images, (regularizer_matrix_size[0] + 1) * (regularizer_matrix_size[1] + 1), - ) + ), + dtype=xp.float32, ) for ii in np.arange(regularizer_matrix_size[0] + 1): Bi = ( @@ -708,7 +929,7 @@ def reconstruct( # Sort by radial order, from center to outer edge inds_order = xp.argsort(xp.sum(xy_vals**2, axis=1)) - shifts_update = xp.zeros((self._num_bf_images, 2)) + shifts_update = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) for a1 in tqdmnd( xy_vals.shape[0], @@ -730,7 +951,7 @@ def reconstruct( xy_shift = align_images_fourier( G_ref, G, - upsample_factor=upsample_factor, + upsample_factor=cross_correlation_upsample_factor, device=self._device, ) @@ -777,11 +998,19 @@ def reconstruct( self._qx_shift[None] * dx[:, None, None] + self._qy_shift[None] * dy[:, None, None] ) + self._stack_BF = xp.real(xp.fft.ifft2(Gs * shift_op)) self._stack_mask = xp.real( xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) ) + self._stack_BF = xp.asarray( + self._stack_BF, dtype=xp.float32 + ) # numpy fft upcasts? + self._stack_mask = xp.asarray( + self._stack_mask, dtype=xp.float32 + ) # numpy fft upcasts? + del Gs # Center the shifts @@ -837,31 +1066,293 @@ def reconstruct( return self + def subpixel_alignment( + self, + kde_upsample_factor=None, + kde_sigma=0.125, + plot_upsampled_BF_comparison: bool = True, + plot_upsampled_FFT_comparison: bool = False, + **kwargs, + ): + """ + Upsample and subpixel-align BFs using the measured image shifts. + Uses kernel density estimation (KDE) to align upsampled BFs. + + Parameters + ---------- + kde_upsample_factor: int, optional + Real-space upsampling factor + kde_sigma: float, optional + KDE gaussian kernel bandwidth + plot_upsampled_BF_comparison: bool, optional + If True, the pre/post alignment BF images are plotted for comparison + plot_upsampled_FFT_comparison: bool, optional + If True, the pre/post alignment BF FFTs are plotted for comparison + + """ + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + xy_shifts = self._xy_shifts + BF_size = np.array(self._stack_BF_no_window.shape[-2:]) + + self._DF_upsample_limit = np.max( + self._region_of_interest_shape / self._scan_shape + ) + self._BF_upsample_limit = ( + 2 * self._kr.max() / self._reciprocal_sampling[0] + ) / self._scan_shape.max() + if self._device == "gpu": + self._BF_upsample_limit = self._BF_upsample_limit.item() + + if kde_upsample_factor is None: + kde_upsample_factor = np.minimum( + self._BF_upsample_limit * 3 / 2, self._DF_upsample_limit + ) + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " + f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." + ), + UserWarning, + ) + + if kde_upsample_factor < 1: + raise ValueError("kde_upsample_factor must be larger than 1") + + if kde_upsample_factor > self._DF_upsample_limit: + warnings.warn( + ( + "Requested upsampling factor exceeds " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f}." + ), + UserWarning, + ) + + self._kde_upsample_factor = kde_upsample_factor + pixel_output = np.round(BF_size * self._kde_upsample_factor).astype("int") + pixel_size = pixel_output.prod() + + # shifted coordinates + x = xp.arange(BF_size[0]) + y = xp.arange(BF_size[1]) + + xa, ya = xp.meshgrid(x, y, indexing="ij") + xa = ((xa + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor).ravel() + ya = ((ya + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor).ravel() + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(self._stack_BF_no_window.ravel(), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + sigma = kde_sigma * self._kde_upsample_factor + pix_count = gaussian_filter(pix_count, sigma) + pix_count[pix_output == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, sigma) + pix_output /= pix_count + + self._recon_BF_subpixel_aligned = pix_output + self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned) + + # plotting + if plot_upsampled_BF_comparison: + if plot_upsampled_FFT_comparison: + figsize = kwargs.pop("figsize", (8, 8)) + fig, axs = plt.subplots(2, 2, figsize=figsize) + else: + figsize = kwargs.pop("figsize", (8, 4)) + fig, axs = plt.subplots(1, 2, figsize=figsize) + + axs = axs.flat + cmap = kwargs.pop("cmap", "magma") + + cropped_object = self._crop_padded_object(self._recon_BF) + cropped_object_aligned = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] + + axs[0].imshow( + cropped_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[0].set_title("Aligned Bright Field") + + axs[1].imshow( + cropped_object_aligned, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[1].set_title("Upsampled Bright Field") + + for ax in axs[:2]: + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if plot_upsampled_FFT_comparison: + recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) + pad_x = np.round( + BF_size[0] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") + pad_y = np.round( + BF_size[1] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") + pad_recon_fft = asnumpy( + xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) + ) + + upsampled_fft = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + ) + ) + + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + ] + + show( + pad_recon_fft, + figax=(fig, axs[2]), + extent=reciprocal_extent, + cmap="gray", + title="Aligned Bright Field FFT", + **kwargs, + ) + + show( + upsampled_fft, + figax=(fig, axs[3]), + extent=reciprocal_extent, + cmap="gray", + title="Upsampled Bright Field FFT", + **kwargs, + ) + + for ax in axs[2:]: + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.xaxis.set_ticks_position("bottom") + + fig.tight_layout() + def aberration_fit( self, - plot_CTF_compare: bool = False, - plot_dk: float = 0.005, - plot_k_sigma: float = 0.02, + fit_BF_shifts: bool = False, + fit_CTF_FFT: bool = False, + fit_aberrations_max_radial_order: int = 3, + fit_aberrations_max_angular_order: int = 4, + fit_aberrations_min_radial_order: int = 2, + fit_aberrations_min_angular_order: int = 0, + fit_max_thon_rings: int = 6, + fit_power_alpha: float = 2.0, + plot_CTF_comparison: bool = None, + plot_BF_shifts_comparison: bool = None, + upsampled: bool = True, + force_transpose: bool = None, ): """ Fit aberrations to the measured image shifts. Parameters ---------- - plot_CTF_compare: bool, optional - If True, the fitted CTF is plotted against the reconstructed frequencies - plot_dk: float, optional - Reciprocal bin-size for polar-averaged FFT - plot_k_sigma: float, optional - sigma to gaussian blur polar-averaged FFT by + fit_BF_shifts: bool + Set to True to fit aberrations to the measured BF shifts directly. + fit_CTF_FFT: bool + Set to True to fit aberrations in the FFT of the (upsampled) BF + image. Note that this method relies on visible zero crossings in the FFT. + fit_aberrations_max_radial_order: int + Max radial order for fitting of aberrations. + fit_aberrations_max_angular_order: int + Max angular order for fitting of aberrations. + fit_aberrations_min_radial_order: int + Min radial order for fitting of aberrations. + fit_aberrations_min_angular_order: int + Min angular order for fitting of aberrations. + fit_max_thon_rings: int + Max number of Thon rings to search for during CTF FFT fitting. + fit_power_alpha: int + Power to raise FFT alpha weighting during CTF FFT fitting. + plot_CTF_comparison: bool, optional + If True, the fitted CTF is plotted against the reconstructed frequencies. + plot_BF_shifts_comparison: bool, optional + If True, the measured vs fitted BF shifts are plotted. + upsampled: bool + If True, and upsampled BF is available, uses that for CTF FFT fitting. + force_transpose: bool + If True, and fit_BF_shifts is True, flips the measured x and y shifts """ xp = self._xp asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter + + ### First pass # Convert real space shifts to Angstroms - self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + + if force_transpose is None: + self.transpose_detected = False + else: + self.transpose_detected = force_transpose + + if force_transpose is True: + self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( + self._scan_sampling + ) + else: + self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) # Solve affine transformation m = asnumpy( @@ -879,123 +1370,518 @@ def aberration_fit( ) m_aberration = -1.0 * m_aberration self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 - self.aberration_A1x = ( - m_aberration[0, 0] - m_aberration[1, 1] - ) / 2.0 # factor /2 for A1 astigmatism? /4? + self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + ### Second pass - # Print results - if self._verbose: - print( - ( - "Rotation of Q w.r.t. R = " - f"{np.rad2deg(self.rotation_Q_to_R_rads):.3f} deg" + # Aberration coefs + mn = [] + + for m in range( + fit_aberrations_min_radial_order - 1, fit_aberrations_max_radial_order + ): + n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) + for n in range(fit_aberrations_min_angular_order, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + + self._aberrations_mn = np.array(mn) + self._aberrations_mn = self._aberrations_mn[ + np.argsort(self._aberrations_mn[:, 1]), : + ] + + sub = self._aberrations_mn[:, 1] > 0 + self._aberrations_mn[sub, :] = self._aberrations_mn[sub, :][ + np.argsort(self._aberrations_mn[sub, 0]), : + ] + self._aberrations_mn[~sub, :] = self._aberrations_mn[~sub, :][ + np.argsort(self._aberrations_mn[~sub, 0]), : + ] + self._aberrations_num = self._aberrations_mn.shape[0] + + if plot_CTF_comparison is None: + if fit_CTF_FFT: + plot_CTF_comparison = True + + if plot_BF_shifts_comparison is None: + if fit_BF_shifts: + plot_BF_shifts_comparison = True + + # Thon Rings Fitting + if fit_CTF_FFT or plot_CTF_comparison: + if upsampled and hasattr(self, "_kde_upsample_factor"): + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + + else: + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + upsampled = False + + # FFT coordinates + qx = xp.fft.fftfreq(im_FFT.shape[0], sx) + qy = xp.fft.fftfreq(im_FFT.shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha_FFT = xp.sqrt(qr2) * self._wavelength + theta_FFT = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + self._aberrations_basis_FFT = xp.zeros( + (alpha_FFT.size, self._aberrations_num) + ) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) / (m + 1) + ).ravel() + + elif a == 0: + # cos coef + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.cos(n * theta_FFT) / (m + 1) + ).ravel() + else: + # sin coef + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.sin(n * theta_FFT) / (m + 1) + ).ravel() + + # global scaling + self._aberrations_basis_FFT *= 2 * np.pi / self._wavelength + self._aberrations_surface_shape_FFT = alpha_FFT.shape + plot_mask = qr2 > np.pi**2 / 4 / np.abs(self.aberration_C1) + angular_mask = np.cos(8.0 * theta_FFT) ** 2 < 0.25 + + # CTF function + def calculate_CTF_FFT(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis_FFT[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis_FFT[:, a0] + return xp.reshape(chi, alpha_shape) + + # Direct Shifts Fitting + if fit_BF_shifts: + # FFT coordinates + sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) + sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) + qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) + qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + u = qx[:, None] * self._wavelength + v = qy[None, :] * self._wavelength + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_du = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_dv = xp.zeros((alpha.size, self._aberrations_num)) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + + if n == 0: + # Radially symmetric basis + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() + self._aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() + + elif a == 0: + # cos coef + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) + / (m + 1) + ).ravel() + + else: + # sin coef + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) + / (m + 1) + ).ravel() + + # global scaling + self._aberrations_basis *= 2 * np.pi / self._wavelength + self._aberrations_surface_shape = alpha.shape + + # CTF function + def calculate_CTF(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis[:, a0] + return xp.reshape(chi, alpha_shape) + + # initial coefficients and plotting intensity range mask + self._aberrations_coefs = np.zeros(self._aberrations_num) + ind = np.argmin( + np.abs(self._aberrations_mn[:, 0] - 1.0) + self._aberrations_mn[:, 1] + ) + self._aberrations_coefs[ind] = self.aberration_C1 + + # Refinement using CTF fitting / Thon rings + if fit_CTF_FFT: + # scoring function to minimize - mean value of zero crossing regions of FFT + def score_CTF(coefs): + im_CTF = xp.abs( + calculate_CTF_FFT(self._aberrations_surface_shape_FFT, *coefs) + ) + mask = xp.logical_and( + im_CTF > 0.5 * np.pi, + im_CTF < (max_num_rings + 0.5) * np.pi, ) + if np.any(mask): + weights = xp.cos(im_CTF[mask]) ** 4 + return asnumpy( + xp.sum( + weights * im_FFT[mask] * alpha_FFT[mask] ** fit_power_alpha + ) + / xp.sum(weights) + ) + else: + return np.inf + + for max_num_rings in range(1, fit_max_thon_rings + 1): + # minimization + res = minimize( + score_CTF, + self._aberrations_coefs, + # method = 'Nelder-Mead', + # method = 'CG', + method="BFGS", + tol=1e-8, + ) + self._aberrations_coefs = res.x + + # Refinement using CTF fitting / Thon rings + elif fit_BF_shifts: + # Gradient basis + corner_indices = self._xy_inds - xp.asarray( + self._region_of_interest_shape // 2 ) - print( + raveled_indices = np.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) + gradients = xp.vstack( ( - "Astigmatism (A1x,A1y) = (" - f"{self.aberration_A1x:.0f}," - f"{self.aberration_A1y:.0f}) Ang" + self._aberrations_basis_du[raveled_indices, :], + self._aberrations_basis_dv[raveled_indices, :], ) ) - if self.aberration_C1 > 0: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + + # (Relative) untransposed fit + tf = AffineTransform(angle=self.rotation_Q_to_R_rads) + rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, rotated_shifts, rcond=None + )[:2] + + if force_transpose is None: + # (Relative) transposed fit + transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) + m_T = asnumpy( + xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[ + 0 + ] + ) + m_rotation_T, _ = polar(m_T, side="right") + rotation_Q_to_R_rads_T = -1 * np.arctan2( + m_rotation_T[1, 0], m_rotation_T[0, 0] + ) + if np.abs( + np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi + ) > (np.pi * 0.5): + rotation_Q_to_R_rads_T = ( + np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi + ) + + tf_T = AffineTransform(angle=rotation_Q_to_R_rads_T) + rotated_shifts_T = tf_T(transposed_shifts, xp=xp).T.ravel() + aberrations_coefs_T, res_T = xp.linalg.lstsq( + gradients, rotated_shifts_T, rcond=None + )[:2] + + # Compare fits + if res_T.sum() < res.sum(): + self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T + self.transpose_detected = not self.transpose_detected + self._aberrations_coefs = asnumpy(aberrations_coefs_T) + self._rotated_shifts = rotated_shifts_T + + warnings.warn( + ( + "Data transpose detected. " + f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" + ), + UserWarning, + ) else: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + self._aberrations_coefs = asnumpy(aberrations_coefs) + self._rotated_shifts = rotated_shifts # Plot the CTF comparison between experiment and fit - if plot_CTF_compare: - # Get polar mean from FFT of BF reconstruction - im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) - - # coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) - kra = xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2) - k_max = xp.max(kra) / np.sqrt(2.0) - k_num_bins = int(xp.ceil(k_max / plot_dk)) - k_bins = xp.arange(k_num_bins + 1) * plot_dk - - # histogram - k_ind = kra / plot_dk - kf = np.floor(k_ind).astype("int") - dk = k_ind - kf - sub = kf <= k_num_bins - hist_exp = xp.bincount( - kf[sub], weights=im_fft[sub] * (1 - dk[sub]), minlength=k_num_bins + if plot_CTF_comparison: + # Generate FFT plotting image + im_scale = asnumpy(im_FFT * alpha_FFT**fit_power_alpha) + int_vals = np.sort(im_scale.ravel()) + int_range = ( + int_vals[np.round(0.02 * im_scale.size).astype("int")], + int_vals[np.round(0.98 * im_scale.size).astype("int")], ) - hist_norm = xp.bincount( - kf[sub], weights=(1 - dk[sub]), minlength=k_num_bins + int_range = ( + int_range[0], + (int_range[1] - int_range[0]) * 1.0 + int_range[0], + ) + im_scale = np.clip( + (np.fft.fftshift(im_scale) - int_range[0]) + / (int_range[1] - int_range[0]), + 0, + 1, ) - sub = kf <= k_num_bins - 1 + im_plot = np.tile(im_scale[:, :, None], (1, 1, 3)) - hist_exp += xp.bincount( - kf[sub] + 1, weights=im_fft[sub] * (dk[sub]), minlength=k_num_bins + # Add CTF zero crossings + im_CTF = calculate_CTF_FFT( + self._aberrations_surface_shape_FFT, *self._aberrations_coefs ) - hist_norm += xp.bincount( - kf[sub] + 1, weights=(dk[sub]), minlength=k_num_bins + im_CTF_cos = xp.cos(xp.abs(im_CTF)) ** 4 + im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2 + im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15 + im_CTF[xp.logical_not(plot_mask)] = 0 + + im_CTF = np.fft.fftshift(asnumpy(im_CTF * angular_mask)) + im_plot[:, :, 0] += im_CTF + im_plot[:, :, 1] -= im_CTF + im_plot[:, :, 2] -= im_CTF + im_plot = np.clip(im_plot, 0, 1) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) + ax1.imshow(im_plot, vmin=int_range[0], vmax=int_range[1]) + + ax2.imshow(np.fft.fftshift(asnumpy(im_CTF_cos)), cmap="gray") + + fig.tight_layout() + + # Plot the measured/fitted shifts comparison + if plot_BF_shifts_comparison: + if not fit_BF_shifts: + raise ValueError() + + measured_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._rotated_shifts[: self._xy_inds.shape[0]] - # KDE and normalizing - k_sigma = plot_dk / plot_k_sigma - hist_exp[0] = 0.0 - hist_exp = gaussian_filter(hist_exp, sigma=k_sigma, mode="nearest") - hist_norm = gaussian_filter(hist_norm, sigma=k_sigma, mode="nearest") - hist_exp /= hist_norm + measured_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._rotated_shifts[self._xy_inds.shape[0] :] - # CTF comparison - CTF_fit = xp.sin( - (-np.pi * self._wavelength * self.aberration_C1) * k_bins**2 + fitted_shifts = xp.tensordot( + gradients, xp.array(self._aberrations_coefs), axes=1 ) - # plotting input - log scale - min_hist_val = xp.max(hist_exp) * 1e-3 - hist_plot = xp.log(np.maximum(hist_exp, min_hist_val)) - hist_plot -= xp.min(hist_plot) - hist_plot /= xp.max(hist_plot) + fitted_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ + : self._xy_inds.shape[0] + ] - hist_plot = asnumpy(hist_plot) - k_bins = asnumpy(k_bins) - CTF_fit = asnumpy(CTF_fit) + fitted_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ + self._xy_inds.shape[0] : + ] - fig, ax = plt.subplots(figsize=(8, 4)) + max_shift = xp.max( + xp.array( + [ + xp.abs(measured_shifts_sx).max(), + xp.abs(measured_shifts_sy).max(), + xp.abs(fitted_shifts_sx).max(), + xp.abs(fitted_shifts_sy).max(), + ] + ) + ) - ax.fill_between( - k_bins, - hist_plot, - color=(0.7, 0.7, 0.7, 1), + show( + [ + [asnumpy(measured_shifts_sx), asnumpy(measured_shifts_sy)], + [asnumpy(fitted_shifts_sx), asnumpy(fitted_shifts_sy)], + ], + cmap="PiYG", + vmin=-max_shift, + vmax=max_shift, + intensity_range="absolute", + axsize=(4, 4), + ticks=False, + title=[ + "Measured Vertical Shifts", + "Measured Horizontal Shifts", + "Fitted Vertical Shifts", + "Fitted Horizontal Shifts", + ], ) - ax.plot( - k_bins, - np.clip(CTF_fit, 0.0, np.inf), - color=(1, 0, 0, 1), - linewidth=2, + self.aberration_dict = { + tuple(self._aberrations_mn[a0]): { + "aberration name": _aberration_names.get( + tuple(self._aberrations_mn[a0, :2]), "-" + ).strip(), + "value [Ang]": self._aberrations_coefs[a0], + } + for a0 in range(self._aberrations_num) + } + + # Print results + if self._verbose: + if fit_CTF_FFT or fit_BF_shifts: + print("Initial Aberration coefficients") + print("-------------------------------") + print( + ( + "Rotation of Q w.r.t. R = " + f"{np.rad2deg(self.rotation_Q_to_R_rads):.3f} deg" + ) ) - ax.plot( - k_bins, - np.clip(-CTF_fit, 0.0, np.inf), - color=(0, 0.5, 1, 1), - linewidth=2, + print( + ( + "Astigmatism (A1x,A1y) = (" + f"{self.aberration_A1x:.0f}," + f"{self.aberration_A1y:.0f}) Ang" + ) ) - ax.set_xlim([0, k_bins[-1]]) - ax.set_ylim([0, 1.05]) + print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") + print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + + if fit_CTF_FFT or fit_BF_shifts: + print() + print("Refined Aberration coefficients") + print("-------------------------------") + print("aberration radial angular dir. coefs") + print("name order order Ang ") + print("---------- ------- ------- ---- -----") + + for a0 in range(self._aberrations_mn.shape[0]): + m, n, a = self._aberrations_mn[a0] + name = _aberration_names.get((m, n), " -- ") + if n == 0: + print( + name + + " " + + str(m + 1) + + " 0 - " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + elif a == 0: + print( + name + + " " + + str(m + 1) + + " " + + str(n) + + " x " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + else: + print( + name + + " " + + str(m + 1) + + " " + + str(n) + + " y " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + def _calculate_CTF(self, alpha_shape, sampling, *coefs): + xp = self._xp + + # FFT coordinates + sx, sy = sampling + qx = xp.fft.fftfreq(alpha_shape[0], sx) + qy = xp.fft.fftfreq(alpha_shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + + # global scaling + aberrations_basis *= 2 * np.pi / self._wavelength + + chi = xp.zeros_like(aberrations_basis[:, 0]) + + for a0 in range(len(coefs)): + chi += coefs[a0] * aberrations_basis[:, a0] + + return xp.reshape(chi, alpha_shape) def aberration_correct( self, + use_CTF_fit=None, plot_corrected_phase: bool = True, k_info_limit: float = None, k_info_power: float = 1.0, Wiener_filter=False, - Wiener_signal_noise_ratio=1.0, - Wiener_filter_low_only=False, + Wiener_signal_noise_ratio: float = 1.0, + Wiener_filter_low_only: bool = False, + upsampled: bool = True, **kwargs, ): """ @@ -1003,6 +1889,9 @@ def aberration_correct( Parameters ---------- + use_FFT_fit: bool + Use the CTF fitted to the zero crossings of the FFT. + Default is True plot_corrected_phase: bool, optional If True, the CTF-corrected phase is plotted k_info_limit: float, optional @@ -1028,46 +1917,79 @@ def aberration_correct( ) ) + if upsampled and hasattr(self, "_kde_upsample_factor"): + im = self._recon_BF_subpixel_aligned + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + else: + upsampled = False + im = self._recon_BF + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + # Fourier coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) + kx = xp.fft.fftfreq(im.shape[0], sx) + ky = xp.fft.fftfreq(im.shape[1], sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - # CTF - sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + if use_CTF_fit is None: + if hasattr(self, "_aberrations_surface_shape"): + use_CTF_fit = True - if Wiener_filter: - SNR_inv = ( - xp.sqrt( - 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) - ) - / Wiener_signal_noise_ratio + if use_CTF_fit: + sin_chi = np.sin( + self._calculate_CTF(im.shape, (sx, sy), *self._aberrations_coefs) ) - CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) - if Wiener_filter_low_only: - # limit Wiener filter to only the part of the CTF before 1st maxima - k_thresh = 1 / xp.sqrt( - 2.0 * self._wavelength * xp.abs(self.aberration_C1) - ) - k_mask = kra2 >= k_thresh**2 - CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr - - else: - # CTF without tilt correction (beyond the parallax operator) CTF_corr = xp.sign(sin_chi) CTF_corr[0, 0] = 0 # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr + im_fft_corr = xp.fft.fft2(im) * CTF_corr # if needed, add low pass filter output image if k_info_limit is not None: im_fft_corr /= 1 + (kra2**k_info_power) / ( (k_info_limit) ** (2 * k_info_power) ) + else: + # CTF + sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + + if Wiener_filter: + SNR_inv = ( + xp.sqrt( + 1 + + (kra2**k_info_power) + / ((k_info_limit) ** (2 * k_info_power)) + ) + / Wiener_signal_noise_ratio + ) + CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) + if Wiener_filter_low_only: + # limit Wiener filter to only the part of the CTF before 1st maxima + k_thresh = 1 / xp.sqrt( + 2.0 * self._wavelength * xp.abs(self.aberration_C1) + ) + k_mask = kra2 >= k_thresh**2 + CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + else: + # CTF without tilt correction (beyond the parallax operator) + CTF_corr = xp.sign(sin_chi) + CTF_corr[0, 0] = 0 + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + # if needed, add low pass filter output image + if k_info_limit is not None: + im_fft_corr /= 1 + (kra2**k_info_power) / ( + (k_info_limit) ** (2 * k_info_power) + ) # Output phase image self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) @@ -1084,12 +2006,14 @@ def aberration_correct( fig, ax = plt.subplots(figsize=figsize) - cropped_object = self._crop_padded_object(self._recon_phase_corrected) + cropped_object = self._crop_padded_object( + self._recon_phase_corrected, upsampled=upsampled + ) extent = [ 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], + sy * cropped_object.shape[1], + sx * cropped_object.shape[0], 0, ] @@ -1246,6 +2170,7 @@ def _crop_padded_object( self, padded_object: np.ndarray, remaining_padding: int = 0, + upsampled: bool = False, ): """ Utility function to crop padded object @@ -1266,8 +2191,19 @@ def _crop_padded_object( asnumpy = self._asnumpy - pad_x = self._object_padding_px[0] // 2 - remaining_padding - pad_y = self._object_padding_px[1] // 2 - remaining_padding + if upsampled: + pad_x = np.round( + self._object_padding_px[0] / 2 * self._kde_upsample_factor + ).astype("int") + pad_y = np.round( + self._object_padding_px[1] / 2 * self._kde_upsample_factor + ).astype("int") + else: + pad_x = self._object_padding_px[0] // 2 + pad_y = self._object_padding_px[1] // 2 + + pad_x -= remaining_padding + pad_y -= remaining_padding return asnumpy(padded_object[pad_x:-pad_x, pad_y:-pad_y]) @@ -1276,6 +2212,7 @@ def _visualize_figax( fig, ax, remaining_padding: int = 0, + upsampled: bool = False, **kwargs, ): """ @@ -1294,14 +2231,31 @@ def _visualize_figax( cmap = kwargs.pop("cmap", "magma") - cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + if upsampled: + cropped_object = self._crop_padded_object( + self._recon_BF_subpixel_aligned, remaining_padding, upsampled + ) - extent = [ - 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], - 0, - ] + extent = [ + 0, + self._scan_sampling[1] + * cropped_object.shape[1] + / self._kde_upsample_factor, + self._scan_sampling[0] + * cropped_object.shape[0] + / self._kde_upsample_factor, + 0, + ] + + else: + cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] ax.imshow( cropped_object, @@ -1310,7 +2264,7 @@ def _visualize_figax( **kwargs, ) - def _visualize_shifts( + def show_shifts( self, scale_arrows=1, plot_arrow_freq=1, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 67dba6115..3eebdb068 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -1,4 +1,7 @@ +import warnings + import numpy as np +import pylops from py4DSTEM.process.phase.utils import ( array_slice, estimate_global_transformation_ransac, @@ -183,6 +186,63 @@ def _object_butterworth_constraint( return current_object + def _object_denoise_tv_pylops(self, current_object, weight, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny) + xy_laplacian = pylops.Laplacian( + (nx, ny), axes=(0, 1), edge=False, kind="backward" + ) + + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weight], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + return current_object_tv + def _object_denoise_tv_chambolle( self, current_object, @@ -229,90 +289,100 @@ def _object_denoise_tv_chambolle( Adapted skimage.restoration.denoise_tv_chambolle. """ xp = self._xp - - current_object_sum = xp.sum(current_object) - if axis is None: - ndim = xp.arange(current_object.ndim).tolist() - elif isinstance(axis, tuple): - ndim = list(axis) + if xp.iscomplexobj(current_object): + updated_object = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) else: - ndim = [axis] - - if pad_object: - pad_width = ((0, 0),) * current_object.ndim - pad_width = list(pad_width) - for ax in range(len(ndim)): - pad_width[ndim[ax]] = (1, 1) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" + current_object_sum = xp.sum(current_object) + if axis is None: + ndim = xp.arange(current_object.ndim).tolist() + elif isinstance(axis, tuple): + ndim = list(axis) + else: + ndim = [axis] + + if pad_object: + pad_width = ((0, 0),) * current_object.ndim + pad_width = list(pad_width) + for ax in range(len(ndim)): + pad_width[ndim[ax]] = (1, 1) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + p = xp.zeros( + (current_object.ndim,) + current_object.shape, + dtype=current_object.dtype, ) + g = xp.zeros_like(p) + d = xp.zeros_like(current_object) + + i = 0 + while i < max_num_iter: + if i > 0: + # d will be the (negative) divergence of p + d = -p.sum(0) + slices_d = [ + slice(None), + ] * current_object.ndim + slices_p = [ + slice(None), + ] * (current_object.ndim + 1) + for ax in range(len(ndim)): + slices_d[ndim[ax]] = slice(1, None) + slices_p[ndim[ax] + 1] = slice(0, -1) + slices_p[0] = ndim[ax] + d[tuple(slices_d)] += p[tuple(slices_p)] + slices_d[ndim[ax]] = slice(None) + slices_p[ndim[ax] + 1] = slice(None) + updated_object = current_object + d + else: + updated_object = current_object + E = (d**2).sum() - p = xp.zeros( - (current_object.ndim,) + current_object.shape, dtype=current_object.dtype - ) - g = xp.zeros_like(p) - d = xp.zeros_like(current_object) - - i = 0 - while i < max_num_iter: - if i > 0: - # d will be the (negative) divergence of p - d = -p.sum(0) - slices_d = [ - slice(None), - ] * current_object.ndim - slices_p = [ + # g stores the gradients of updated_object along each axis + # e.g. g[0] is the first order finite difference along axis 0 + slices_g = [ slice(None), ] * (current_object.ndim + 1) for ax in range(len(ndim)): - slices_d[ndim[ax]] = slice(1, None) - slices_p[ndim[ax] + 1] = slice(0, -1) - slices_p[0] = ndim[ax] - d[tuple(slices_d)] += p[tuple(slices_p)] - slices_d[ndim[ax]] = slice(None) - slices_p[ndim[ax] + 1] = slice(None) - updated_object = current_object + d - else: - updated_object = current_object - E = (d**2).sum() - - # g stores the gradients of updated_object along each axis - # e.g. g[0] is the first order finite difference along axis 0 - slices_g = [ - slice(None), - ] * (current_object.ndim + 1) - for ax in range(len(ndim)): - slices_g[ndim[ax] + 1] = slice(0, -1) - slices_g[0] = ndim[ax] - g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) - slices_g[ndim[ax] + 1] = slice(None) - if scaling is not None: - scaling /= xp.max(scaling) - g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] - norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] - E += weight * norm.sum() - tau = 1.0 / (2.0 * len(ndim)) - norm *= tau / weight - norm += 1.0 - p -= tau * g - p /= norm - E /= float(current_object.size) - if i == 0: - E_init = E - E_previous = E - else: - if xp.abs(E_previous - E) < eps * E_init: - break - else: + slices_g[ndim[ax] + 1] = slice(0, -1) + slices_g[0] = ndim[ax] + g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) + slices_g[ndim[ax] + 1] = slice(None) + if scaling is not None: + scaling /= xp.max(scaling) + g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] + norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] + E += weight * norm.sum() + tau = 1.0 / (2.0 * len(ndim)) + norm *= tau / weight + norm += 1.0 + p -= tau * g + p /= norm + E /= float(current_object.size) + if i == 0: + E_init = E E_previous = E - i += 1 + else: + if xp.abs(E_previous - E) < eps * E_init: + break + else: + E_previous = E + i += 1 - if pad_object: - for ax in range(len(ndim)): - slices = array_slice(ndim[ax], current_object.ndim, 1, -1) - updated_object = updated_object[slices] + if pad_object: + for ax in range(len(ndim)): + slices = array_slice(ndim[ax], current_object.ndim, 1, -1) + updated_object = updated_object[slices] + updated_object = ( + updated_object / xp.sum(updated_object) * current_object_sum + ) - return updated_object / xp.sum(updated_object) * current_object_sum + return updated_object def _probe_center_of_mass_constraint(self, current_probe): """ @@ -363,7 +433,7 @@ def _probe_amplitude_constraint( xp = self._xp erf = self._erf - probe_intensity = xp.abs(current_probe) ** 2 + # probe_intensity = xp.abs(current_probe) ** 2 # current_probe_sum = xp.sum(probe_intensity) X = xp.fft.fftfreq(current_probe.shape[0])[:, None] @@ -485,10 +555,12 @@ def _probe_aberration_fitting_constraint( fourier_probe = xp.fft.fft2(current_probe) fourier_probe_abs = xp.abs(fourier_probe) sampling = self.sampling + energy = self._energy fitted_angle, _ = fit_aberration_surface( fourier_probe, sampling, + energy, max_angular_order, max_radial_order, xp=xp, diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index e3713cde1..37438852f 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -192,6 +192,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -246,6 +247,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -401,9 +404,7 @@ def preprocess( amplitudes_0, mean_diffraction_intensity_0, ) = self._normalize_diffraction_intensities( - intensities_0, - com_fitted_x_0, - com_fitted_y_0, + intensities_0, com_fitted_x_0, com_fitted_y_0, crop_patterns ) # explicitly delete namescapes @@ -484,9 +485,7 @@ def preprocess( amplitudes_1, mean_diffraction_intensity_1, ) = self._normalize_diffraction_intensities( - intensities_1, - com_fitted_x_1, - com_fitted_y_1, + intensities_1, com_fitted_x_1, com_fitted_y_1, crop_patterns ) # explicitly delete namescapes @@ -568,9 +567,7 @@ def preprocess( amplitudes_2, mean_diffraction_intensity_2, ) = self._normalize_diffraction_intensities( - intensities_2, - com_fitted_x_2, - com_fitted_y_2, + intensities_2, com_fitted_x_2, com_fitted_y_2, crop_patterns ) # explicitly delete namescapes @@ -683,6 +680,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -746,19 +747,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -780,23 +775,22 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax2.scatter( self.positions[:, 1], @@ -808,7 +802,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -2232,6 +2226,9 @@ def _constraints( q_highpass_e, q_highpass_m, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, warmup_iteration, object_positivity, shrinkage_rad, @@ -2300,6 +2297,12 @@ def _constraints( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising warmup_iteration: bool If True, constraints electrostatic object only object_positivity: bool @@ -2349,6 +2352,15 @@ def _constraints( if self._object_type == "complex": magnetic_obj = magnetic_obj.real + if tv_denoise: + electrostatic_obj = self._object_denoise_tv_pylops( + electrostatic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) + + if not warmup_iteration: + magnetic_obj = self._object_denoise_tv_pylops( + magnetic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) if shrinkage_rad > 0.0 or object_mask is not None: electrostatic_obj = self._object_shrinkage_constraint( @@ -2446,6 +2458,9 @@ def reconstruct( q_highpass_e: float = None, q_highpass_m: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -2538,6 +2553,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -2748,6 +2769,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = (None,) * self._num_sim_measurements self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2899,6 +2922,9 @@ def reconstruct( q_highpass_e=q_highpass_e, q_highpass_m=q_highpass_m, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -3029,8 +3055,6 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (12, 5)) cmap_e = kwargs.pop("cmap_e", "magma") cmap_m = kwargs.pop("cmap_m", "PuOr") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj_e = np.angle(self.object[0]) @@ -3052,6 +3076,11 @@ def _visualize_last_iteration( vmin_m = kwargs.pop("vmin_m", min_m) vmax_m = kwargs.pop("vmax_m", max_m) + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + extent = [ 0, self.sampling[1] * rotated_shape[1], @@ -3156,29 +3185,29 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, power=2, chroma_boost=chroma_boost ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: # Electrostatic Object @@ -3229,10 +3258,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index df0ef5e1c..5dd19d7bd 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -188,6 +188,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -245,6 +246,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -330,9 +333,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -412,6 +413,11 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) + self._probe = ( ComplexProbe( gpts=self._region_of_interest_shape, @@ -474,19 +480,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -508,23 +508,19 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="gray", ) ax2.scatter( self.positions[:, 1], @@ -536,7 +532,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -1023,6 +1019,9 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, object_positivity, shrinkage_rad, object_mask, @@ -1078,6 +1077,12 @@ def _constraints( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1108,6 +1113,11 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1198,6 +1208,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1284,6 +1297,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1486,6 +1505,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -1618,6 +1639,9 @@ def reconstruct( q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -1734,8 +1758,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -1828,29 +1855,31 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1883,10 +1912,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -1957,9 +1986,12 @@ def _visualize_all_iterations( else (3 * iterations_grid[1], 3 * iterations_grid[0]) ) figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "inferno") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2063,8 +2095,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2072,21 +2103,23 @@ def _visualize_all_iterations( else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2098,7 +2131,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index c2e1d3b77..d29765d04 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1543,39 +1543,77 @@ def step_model(radius, sig_0, rad_0, width): def aberrations_basis_function( probe_size, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, ): """ """ + + # Add constant phase shift in basis + mn = [[-1, 0, 0]] + + for m in range(1, max_radial_order): + n_max = np.minimum(max_angular_order, m + 1) + for n in range(0, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + + aberrations_mn = np.array(mn) + aberrations_mn = aberrations_mn[np.argsort(aberrations_mn[:, 1]), :] + + sub = aberrations_mn[:, 1] > 0 + aberrations_mn[sub, :] = aberrations_mn[sub, :][ + np.argsort(aberrations_mn[sub, 0]), : + ] + aberrations_mn[~sub, :] = aberrations_mn[~sub, :][ + np.argsort(aberrations_mn[~sub, 0]), : + ] + aberrations_num = aberrations_mn.shape[0] + sx, sy = probe_size dx, dy = probe_sampling + wavelength = electron_wavelength_angstrom(energy) + qx = xp.fft.fftfreq(sx, dx) qy = xp.fft.fftfreq(sy, dy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + alpha = xp.sqrt(qr2) * wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.ones((alpha.size, aberrations_num)) + + # Skip constant to avoid dividing by zero in normalization + for a0 in range(1, aberrations_num): + m, n, a = aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() - qxa, qya = xp.meshgrid(qx, qy, indexing="ij") - q2 = qxa**2 + qya**2 - theta = xp.arctan2(qya, qxa) - - basis = [] - index = [] - - for n in range(max_angular_order + 1): - for m in range((max_radial_order - n) // 2 + 1): - basis.append((q2 ** (m + n / 2) * np.cos(n * theta))) - index.append((m, n, 0)) - if n > 0: - basis.append((q2 ** (m + n / 2) * np.sin(n * theta))) - index.append((m, n, 1)) - - basis = xp.array(basis) + # global scaling + aberrations_basis *= 2 * np.pi / wavelength - return basis, index + return aberrations_basis, aberrations_mn def fit_aberration_surface( complex_probe, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, @@ -1592,21 +1630,47 @@ def fit_aberration_surface( unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True) unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32) - basis, _ = aberrations_basis_function( + raveled_basis, _ = aberrations_basis_function( complex_probe.shape, probe_sampling, + energy, max_angular_order, max_radial_order, xp=xp, ) - raveled_basis = basis.reshape((basis.shape[0], -1)) raveled_weights = probe_amp.ravel() - Aw = raveled_basis.T * raveled_weights[:, None] + Aw = raveled_basis * raveled_weights[:, None] bw = unwrapped_angle.ravel() * raveled_weights coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] - fitted_angle = xp.tensordot(coeff, basis, axes=1) + fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) return fitted_angle, coeff + + +def rotate_point(origin, point, angle): + """ + Rotate a point (x1, y1) counterclockwise by a given angle around + a given origin (x0, y0). + + Parameters + -------- + origin: 2-tuple of floats + (x0, y0) + point: 2-tuple of floats + (x1, y1) + angle: float (radians) + + Returns + -------- + rotated points (2-tuple) + + """ + ox, oy = origin + px, py = point + + qx = ox + np.cos(angle) * (px - ox) - np.sin(angle) * (py - oy) + qy = oy + np.sin(angle) * (px - ox) + np.cos(angle) * (py - oy) + return qx, qy diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 89f09606a..cfa017299 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -1,6 +1,5 @@ from matplotlib import cm, colors as mcolors, pyplot as plt import numpy as np -from matplotlib.colors import hsv_to_rgb from matplotlib.patches import Wedge from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.spatial import Voronoi @@ -17,6 +16,7 @@ ) from py4DSTEM.visualize.vis_grid import show_image_grid from py4DSTEM.visualize.vis_RQ import ax_addaxes, ax_addaxes_QtoR +from colorspacious import cspace_convert def show_elliptical_fit( @@ -937,15 +937,20 @@ def show_selected_dps( ) -def Complex2RGB(complex_data, vmin=None, vmax=None, hue_start=0, invert=False): +def Complex2RGB(complex_data, vmin=None, vmax=None, power=None, chroma_boost=1): """ complex_data (array): complex array to plot vmin (float) : minimum absolute value vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + power (float) : power to raise amplitude to + chroma_boost (float): boosts chroma for higher-contrast (~1-2.5) """ amp = np.abs(complex_data) + phase = np.angle(complex_data) + + if power is not None: + amp = amp**power + if np.isclose(np.max(amp), np.min(amp)): if vmin is None: vmin = 0 @@ -966,36 +971,40 @@ def Complex2RGB(complex_data, vmin=None, vmax=None, hue_start=0, invert=False): amp = np.where(amp < vmin, vmin, amp) amp = np.where(amp > vmax, vmax, amp) + amp = ((amp - vmin) / vmax).clip(1e-16, 1) + + J = amp * 61.5 # Note we restrict luminance to the monotonic chroma cutoff + C = np.minimum(chroma_boost * 98 * J / 123, 110) + h = np.rad2deg(phase) + 180 - phase = np.angle(complex_data) + np.deg2rad(hue_start) - amp /= np.max(amp) - rgb = np.zeros(phase.shape + (3,)) - rgb[..., 0] = 0.5 * (np.sin(phase) + 1) * amp - rgb[..., 1] = 0.5 * (np.sin(phase + np.pi / 2) + 1) * amp - rgb[..., 2] = 0.5 * (-np.sin(phase) + 1) * amp + JCh = np.stack((J, C, h), axis=-1) + rgb = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) - return 1 - rgb if invert else rgb + return rgb -def add_colorbar_arg(cax, vmin=None, vmax=None, hue_start=0, invert=False): +def add_colorbar_arg(cax, chroma_boost=1, c=49, j=61.5): """ - cax : axis to add cbar too - vmin (float) : minimum absolute value - vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + cax : axis to add cbar to + chroma_boost (float): boosts chroma for higher-contrast (~1-2.25) + c (float) : constant chroma value + j (float) : constant luminance value """ - z = np.exp(1j * np.linspace(-np.pi, np.pi, 200)) - rgb_vals = Complex2RGB(z, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert) + + h = np.linspace(0, 360, 256, endpoint=False) + J = np.full_like(h, j) + C = np.full_like(h, np.minimum(c * chroma_boost, 110)) + JCh = np.stack((J, C, h), axis=-1) + rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) newcmp = mcolors.ListedColormap(rgb_vals) norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi) - cb1 = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) + cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) - cb1.set_label("arg", rotation=0, ha="center", va="bottom") - cb1.ax.yaxis.set_label_coords(0.5, 1.01) - cb1.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) - cb1.set_ticklabels( + cb.set_label("arg", rotation=0, ha="center", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + cb.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) + cb.set_ticklabels( [r"$-\pi$", r"$-\dfrac{\pi}{2}$", "$0$", r"$\dfrac{\pi}{2}$", r"$\pi$"] ) @@ -1004,13 +1013,13 @@ def show_complex( ar_complex, vmin=None, vmax=None, + power=None, + chroma_boost=1, cbar=True, scalebar=False, pixelunits="pixels", pixelsize=1, returnfig=False, - hue_start=0, - invert=False, **kwargs ): """ @@ -1023,13 +1032,13 @@ def show_complex( vmax (float, optional) : maximum absolute value if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, e.g. vmin=0.02 will set the minumum display value to saturate the lower 2% of pixels - cbar (bool, optional) : if True, include color wheel + power (float,optional) : power to raise amplitude to + chroma_boost (float) : boosts chroma for higher-contrast (~1-2.25) + cbar (bool, optional) : if True, include color bar scalebar (bool, optional) : if True, adds scale bar pixelunits (str, optional) : units for scalebar pixelsize (float, optional) : size of one pixel in pixelunits for scalebar returnfig (bool, optional) : if True, the function returns the tuple (figure,axis) - hue_start (float, optional) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme Returns: if returnfig==False (default), the figure is plotted and nothing is returned. @@ -1044,7 +1053,7 @@ def show_complex( if isinstance(ar_complex, list): if isinstance(ar_complex[0], list): rgb = [ - Complex2RGB(ar, vmin, vmax, hue_start=hue_start, invert=invert) + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) for sublist in ar_complex for ar in sublist ] @@ -1053,7 +1062,7 @@ def show_complex( else: rgb = [ - Complex2RGB(ar, vmin, vmax, hue_start=hue_start, invert=invert) + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) for ar in ar_complex ] if len(rgb[0].shape) == 4: @@ -1064,7 +1073,9 @@ def show_complex( W = len(ar_complex) is_grid = True else: - rgb = Complex2RGB(ar_complex, vmin, vmax, hue_start=hue_start, invert=invert) + rgb = Complex2RGB( + ar_complex, vmin, vmax, power=power, chroma_boost=chroma_boost + ) if len(rgb.shape) == 4: is_grid = True H = 1 @@ -1115,37 +1126,18 @@ def show_complex( add_scalebar(ax, scalebar) # add color bar - if cbar == True: - ax0 = fig.add_axes([1, 0.35, 0.3, 0.3]) - - # create wheel - AA = 1000 - kx = np.fft.fftshift(np.fft.fftfreq(AA)) - ky = np.fft.fftshift(np.fft.fftfreq(AA)) - kya, kxa = np.meshgrid(ky, kx) - kra = (kya**2 + kxa**2) ** 0.5 - ktheta = np.arctan2(-kxa, kya) - ktheta = kra * np.exp(1j * ktheta) - - # convert to hsv - rgb = Complex2RGB(ktheta, 0, 0.4, hue_start=hue_start, invert=invert) - ind = kra > 0.4 - rgb[ind] = [1, 1, 1] - - # plot - ax0.imshow(rgb) - - # add axes - ax0.axhline(AA / 2, 0, AA, color="k") - ax0.axvline(AA / 2, 0, AA, color="k") - ax0.axis("off") - - label_size = 16 - - ax0.text(AA, AA / 2, 1, fontsize=label_size) - ax0.text(AA / 2, 0, "i", fontsize=label_size) - ax0.text(AA / 2, AA, "-i", fontsize=label_size) - ax0.text(0, AA / 2, -1, fontsize=label_size) - - if returnfig == True: + if cbar: + if is_grid: + for ax_flat in ax.flatten(): + divider = make_axes_locatable(ax_flat) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb) + else: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + fig.tight_layout() + + if returnfig: return fig, ax diff --git a/setup.py b/setup.py index c3cbbd151..d99bc978c 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,8 @@ "emdfile >= 0.0.13", "mpire >= 2.7.1", "threadpoolctl >= 3.1.0", + "pylops >= 2.1.0", + "colorspacious >= 1.1.2", ], extras_require={ "ipyparallel": ["ipyparallel >= 6.2.4", "dill >= 0.3.3"],