diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 29b76165c..c92f0d38b 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -14,6 +14,7 @@ from py4DSTEM.process.utils.cross_correlate import align_and_shift_images from py4DSTEM.process.utils.utils import electron_wavelength_angstrom +from py4DSTEM.process.utils import get_CoM from scipy.ndimage import gaussian_filter # fmt: off @@ -1277,24 +1278,20 @@ def nesterov_gamma(zero_indexed_iter_num): ) - - - - def regularize_probe_amp( probe_init, - width_max_pixels = 2.0, - enforce_constant_intensity = True, - return_coefs = False, - plot_result = False, - plot_polar = False, - cmap = 'turbo', - figsize = (5,5), - ): + width_max_pixels=2.0, + enforce_constant_intensity=True, + return_coefs=False, + plot_result=False, + plot_polar=False, + cmap="turbo", + figsize=(5, 5), +): """ - Assume that the probe is centered in Fourier space. Note we + Assumes the probe is corner-centered in Fourier space. Note we re-implemented the polar/cartesian transforms here for portability. - + Parameters -------- probe_init: np.array @@ -1323,50 +1320,36 @@ def regularize_probe_amp( probe_corr: np.array 2D complex image of the corrected probe in Fourier space. coefs_all: np.array (optional) - coefficients for the - - - + coefficients for the """ # Get probe intensity probe_amp = np.abs(probe_init) probe_int = probe_amp**2 - # coordinates - xa,ya = np.meshgrid( - np.arange(probe_init.shape[0]), - np.arange(probe_init.shape[1]), - indexing = 'ij', - ) - # Center of mass for probe intensity - int_total = np.sum(probe_int) - xy_center = ( - np.sum(probe_int * xa) / int_total, - np.sum(probe_int * ya) / int_total, - ) + xy_center = get_CoM(probe_int, device="cpu", corner_centered=True) # Convert intensity to polar coordinates - polar_int = im_cart_to_polar( + polar_int = cartesian_to_polar_transform_2Ddata( probe_int, - xy_center = xy_center, - ) + xy_center=xy_center, + ) # Fit corrected probe intensity radius = np.arange(polar_int.shape[1]) # estimate initial parameters - sub = polar_int > (np.max(polar_int)*0.5) + sub = polar_int > (np.max(polar_int) * 0.5) sig_0 = np.mean(polar_int[sub]) - rad_0 = np.max(np.argwhere(np.sum(sub,axis=0))) + rad_0 = np.max(np.argwhere(np.sum(sub, axis=0))) width = width_max_pixels * 0.5 # init - coefs_all = np.zeros((polar_int.shape[0],3)) - coefs_all[:,0] = sig_0 - coefs_all[:,1] = rad_0 - coefs_all[:,2] = width + coefs_all = np.zeros((polar_int.shape[0], 3)) + coefs_all[:, 0] = sig_0 + coefs_all[:, 1] = rad_0 + coefs_all[:, 2] = width # bounds lb = (0.0, 0.0, 1e-4) @@ -1375,80 +1358,79 @@ def regularize_probe_amp( # refine parameters, generate polar image polar_fit = np.zeros_like(polar_int) for a0 in range(polar_int.shape[0]): - coefs_all[a0,:] = curve_fit( - step_model, - radius, - polar_int[a0,:], - p0 = coefs_all[a0,:], - xtol = 1e-12, - bounds = (lb,ub), - )[0] - polar_fit[a0,:] = step_model( + coefs_all[a0, :] = curve_fit( + step_model, radius, - coefs_all[a0,:]) + polar_int[a0, :], + p0=coefs_all[a0, :], + xtol=1e-12, + bounds=(lb, ub), + )[0] + polar_fit[a0, :] = step_model(radius, coefs_all[a0, :]) if enforce_constant_intensity: # Compute best-fit constant intensity inside probe, update bounds - sig_0 = np.median(coefs_all[:,0]) - coefs_all[:,0] = sig_0 - lb = (sig_0-1e-8, 0.0, 1e-4) - ub = (sig_0+1e-8, np.inf, width_max_pixels) + sig_0 = np.median(coefs_all[:, 0]) + coefs_all[:, 0] = sig_0 + lb = (sig_0 - 1e-8, 0.0, 1e-4) + ub = (sig_0 + 1e-8, np.inf, width_max_pixels) # refine parameters, generate polar image polar_int_corr = np.zeros_like(polar_int) for a0 in range(polar_int.shape[0]): - coefs_all[a0,:] = curve_fit( - step_model, - radius, - polar_int[a0,:], - p0 = coefs_all[a0,:], - xtol = 1e-12, - bounds = (lb,ub), - )[0] - polar_int_corr[a0,:] = step_model( + coefs_all[a0, :] = curve_fit( + step_model, radius, - coefs_all[a0,:]) + polar_int[a0, :], + p0=coefs_all[a0, :], + xtol=1e-12, + bounds=(lb, ub), + )[0] + polar_int_corr[a0, :] = step_model(radius, coefs_all[a0, :]) else: polar_int_corr = polar_fit # Convert back to cartesian coordinates - int_corr = im_polar_to_cart( + int_corr = polar_to_cartesian_transform_2Ddata( polar_int_corr, - xy_size = probe_init.shape, - xy_center = xy_center, - ) + xy_size=probe_init.shape, + xy_center=xy_center, + ) # Assemble output probe - probe_corr = np.sqrt(np.maximum(int_corr,0)) \ - * np.exp(1j*np.angle(probe_init)) + probe_corr = np.sqrt(np.maximum(int_corr, 0)) * np.exp(1j * np.angle(probe_init)) # plotting if plot_result: - fig,ax = plt.subplots(figsize = (figsize[0]*2, figsize[1])) + fig, ax = plt.subplots(figsize=(figsize[0] * 2, figsize[1])) ax.imshow( - np.hstack(( - probe_int, - int_corr, - )), - cmap = 'turbo', - ) + np.hstack( + ( + np.fft.fftshift(probe_int), + np.fft.fftshift(int_corr), + ) + ), + cmap="turbo", + ) if plot_polar: - fig,ax = plt.subplots(figsize = figsize) + fig, ax = plt.subplots(figsize=figsize) ax.imshow( - np.hstack(( - polar_int, - polar_fit, - polar_int_corr, - )), - cmap = 'turbo', - ) + np.hstack( + ( + polar_int, + polar_fit, + polar_int_corr, + ) + ), + cmap="turbo", + ) if return_coefs: return probe_corr, coefs_all else: return probe_corr - + def step_model(radius, *coefs): coefs = np.squeeze(np.array(coefs)) @@ -1460,103 +1442,118 @@ def step_model(radius, *coefs): return sig_0 * np.clip((rad_0 - radius) / width, 0.0, 1.0) -def im_cart_to_polar( +def cartesian_to_polar_transform_2Ddata( im_cart, xy_center, - num_theta_bins = 180, - radius_max = None, - ): + num_theta_bins=180, + radius_max=None, +): """ Quick cartesian to polar conversion. """ # coordinates - xa,ya = np.meshgrid( - np.arange(im_cart.shape[0]), - np.arange(im_cart.shape[1]), - indexing = 'ij', - ) if radius_max is None: - radius_max = np.ceil(np.sqrt(np.sum( - np.array(im_cart.shape).astype('float')**2 - )) / 2.0).astype('int') + radius_max = np.min(np.array(im_cart.shape) // 2) + r = np.arange(radius_max) t = np.linspace( 0, - 2.0*np.pi, + 2.0 * np.pi, num_theta_bins, - endpoint = False, - ) - ra,ta = np.meshgrid(r,t) + endpoint=False, + ) + ra, ta = np.meshgrid(r, t) # resampling coordinates - x = (ra * np.cos(ta) + xy_center[0]) - y = (ra * np.sin(ta) + xy_center[1]) - xf = np.floor(x).astype('int') - yf = np.floor(y).astype('int') + x = ra * np.cos(ta) + xy_center[0] + y = ra * np.sin(ta) + xy_center[1] + + xf = np.floor(x).astype("int") + yf = np.floor(y).astype("int") dx = x - xf dy = y - yf # resample image - im_polar = \ - im_cart.ravel()[np.ravel_multi_index( - (xf, yf), - im_cart.shape, - mode='clip', - )] * (1-dx) * (1-dy) + \ - im_cart.ravel()[np.ravel_multi_index( - (xf+1, yf), - im_cart.shape, - mode='clip', - )] * ( dx) * (1-dy) + \ - im_cart.ravel()[np.ravel_multi_index( - (xf, yf+1), - im_cart.shape, - mode='clip', - )] * (1-dx) * ( dy) + \ - im_cart.ravel()[np.ravel_multi_index( - (xf+1, yf+1), - im_cart.shape, - mode='clip', - )] * ( dx) * ( dy) + im_polar = ( + im_cart.ravel()[ + np.ravel_multi_index( + (xf, yf), + im_cart.shape, + mode="wrap", + ) + ] + * (1 - dx) + * (1 - dy) + + im_cart.ravel()[ + np.ravel_multi_index( + (xf + 1, yf), + im_cart.shape, + mode="wrap", + ) + ] + * (dx) + * (1 - dy) + + im_cart.ravel()[ + np.ravel_multi_index( + (xf, yf + 1), + im_cart.shape, + mode="wrap", + ) + ] + * (1 - dx) + * (dy) + + im_cart.ravel()[ + np.ravel_multi_index( + (xf + 1, yf + 1), + im_cart.shape, + mode="wrap", + ) + ] + * (dx) + * (dy) + ) return im_polar -def im_polar_to_cart( +def polar_to_cartesian_transform_2Ddata( im_polar, xy_size, xy_center, - ): +): """ Quick cartesian to polar conversion. """ # coordinates - xa,ya = np.meshgrid( - np.arange(xy_size[0]) - xy_center[0], - np.arange(xy_size[1]) - xy_center[1], - indexing = 'ij', - ) - ra = np.sqrt(xa**2 + ya**2) - ta = np.arctan2(ya,xa) - t = np.linspace(0,2*np.pi,im_polar.shape[0],endpoint = False) + sx, sy = xy_size + cx, cy = xy_center + + x = np.fft.fftfreq(sx, d=1 / sx) + y = np.fft.fftfreq(sy, d=1 / sy) + xa, ya = np.meshgrid(x, y, indexing="ij") + ra = np.hypot(xa - cx, ya - cy) + ta = np.arctan2(ya, xa) + + t = np.linspace(0, 2 * np.pi, im_polar.shape[0], endpoint=False) t_step = t[1] - t[0] # resampling coordinates t_ind = ta / t_step r_ind = ra.copy() - tf = np.floor(t_ind).astype('int') - rf = np.floor(r_ind).astype('int') + tf = np.floor(t_ind).astype("int") + rf = np.floor(r_ind).astype("int") dt = t_ind - tf dr = r_ind - rf # resample image - im_cart = im_polar.ravel()[np.ravel_multi_index( - (np.mod(tf, im_polar.shape[0]), rf), + im_cart = im_polar.ravel()[ + np.ravel_multi_index( + (tf, rf), im_polar.shape, - mode='clip', - )] + mode=("wrap", "clip"), + ) + ] return im_cart -