Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Double double, aberrations, bugs, and trouble #539

Merged
merged 71 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
6db566e
start for depth profile
smribet Aug 4, 2023
4eb461d
adding real-space kde upsampling
gvarnavi Aug 8, 2023
0176914
Merge remote-tracking branch 'origin/phase_contrast' into phase_contrast
smribet Aug 8, 2023
8c61a5b
more depth profile
smribet Aug 8, 2023
22d1bb4
saving error
smribet Aug 8, 2023
5f919bd
small name changes
smribet Aug 8, 2023
68b31a7
single slice tv denoise
smribet Aug 17, 2023
5acdd58
multislice tv denoise... will test more before adding to other classe…
smribet Aug 17, 2023
aac2d12
minor tv bugfix
gvarnavi Aug 21, 2023
da85db0
improvements for depth plotting
smribet Aug 21, 2023
a582212
tv fixes
smribet Aug 22, 2023
eaa3699
everyone gets TV denoising
smribet Aug 25, 2023
cc36490
subpixel alignment phase correct, part 1
smribet Aug 29, 2023
41f319d
removing print statement
smribet Aug 29, 2023
95beba5
parallax plotting fix
smribet Sep 7, 2023
c25e242
generalizing overlap tomo to orientation matrices
gvarnavi Sep 20, 2023
a7b55aa
Merge branch 'phase_contrast' of github.com:py4dstem/py4DSTEM into ph…
gvarnavi Sep 20, 2023
31b5525
black formatting
gvarnavi Sep 20, 2023
d4363f4
flake8 6.1.0 found some more issues
gvarnavi Sep 20, 2023
e60071a
adding mixed-state multi-slice ptycho class
gvarnavi Sep 20, 2023
807ac15
small reset bug
smribet Sep 22, 2023
8b35dbb
fixed NaN bug
gvarnavi Sep 23, 2023
9f24d5c
Merge branch 'phase_contrast' of github.com:py4dstem/py4DSTEM into ph…
gvarnavi Sep 23, 2023
6cb62f1
changed complex plotting
gvarnavi Sep 23, 2023
ad79416
updated complex plotting phase calls
gvarnavi Sep 24, 2023
356a3a1
adding complex CoM plotting and various dpc plotting bugs
gvarnavi Sep 24, 2023
2bc9da9
parallax descan correct
smribet Sep 30, 2023
3a99d5a
complex plotting improvements, formatting
gvarnavi Sep 30, 2023
d4099ea
change to fitted intensities
smribet Oct 4, 2023
8c6be92
Merge remote-tracking branch 'origin/phase_contrast' into phase_contrast
smribet Oct 4, 2023
f2c21d5
preprocessing dtype bug
smribet Oct 8, 2023
7c3a0d8
adding tilt to propagators
smribet Oct 10, 2023
23204aa
single slice crop patterns
smribet Oct 12, 2023
84a2067
tv_denoise typo
gvarnavi Oct 13, 2023
0978c4b
revisting casting inconsistencies
gvarnavi Oct 13, 2023
41069fd
Merge remote-tracking branch 'origin/phase_contrast' into phase_contrast
smribet Oct 14, 2023
c9ac5db
crop pattern option for all classes
smribet Oct 14, 2023
00991c9
fix for gpu
smribet Oct 14, 2023
31af429
cleaned up parallax descan
gvarnavi Oct 16, 2023
43289e3
added support for float upsampling
gvarnavi Oct 16, 2023
35d076f
making descan correction the default
gvarnavi Oct 16, 2023
d2a198e
Merge remote-tracking branch 'origin/phase_contrast' into phase_contrast
smribet Oct 16, 2023
2e25fe1
Merge remote-tracking branch 'origin/phase_contrast' into phase_contrast
smribet Oct 16, 2023
c06ca46
removing redundant if statement
gvarnavi Oct 16, 2023
9be2e9a
removed separate ctf corrections and other subpixel improvements
gvarnavi Oct 16, 2023
ab76946
added read-write functionality to parralax
gvarnavi Oct 16, 2023
06e1837
Starting on CTF fitting
cophus Oct 16, 2023
593f07d
Adding more parts of parallax CTF fitting
cophus Oct 17, 2023
e5d7425
added chroma_boost for show_complex
gvarnavi Oct 17, 2023
62c4cae
Working on CTF
cophus Oct 17, 2023
ea070f8
It works!
cophus Oct 17, 2023
b2cbede
Updating outputs
cophus Oct 17, 2023
2705514
Adding outputs, plotting
cophus Oct 17, 2023
2e2690c
Merge pull request #536 from cophus/phase_contrast
gvarnavi Oct 18, 2023
7eae948
finally works
gvarnavi Oct 19, 2023
d1f6efb
some support for aberration correct
gvarnavi Oct 19, 2023
de75b51
small bug fixes
gvarnavi Oct 19, 2023
52ee427
cleaned up parallax
gvarnavi Oct 20, 2023
54d5859
ptycho new aberration formalism
gvarnavi Oct 20, 2023
9865f39
adding chroma_boost defaults
gvarnavi Oct 20, 2023
020e170
formatted, linted, isorted
gvarnavi Oct 20, 2023
f9ae5e1
fixed some manual conflicts
gvarnavi Oct 20, 2023
9806e27
fixing radial order accounting
gvarnavi Oct 20, 2023
2eca4b7
make lint happy I hope!
smribet Oct 20, 2023
71b8f6d
fix extent for ms depth sectioning
smribet Oct 21, 2023
dfc312b
small fixes
smribet Oct 21, 2023
b0e2c42
fix for ptycho aberration fit
smribet Oct 21, 2023
17dd9a2
black format
smribet Oct 21, 2023
ada4d4d
fixed ptycho fitting, added transpose flag in parallax
gvarnavi Oct 22, 2023
9529945
added force_transpose option for other two aberration fit methods
gvarnavi Oct 22, 2023
43220d0
read-write device bugfix
gvarnavi Oct 22, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 8 additions & 22 deletions py4DSTEM/process/phase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,14 @@
_emd_hook = True

from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction
from py4DSTEM.process.phase.iterative_mixedstate_ptychography import (
MixedstatePtychographicReconstruction,
)
from py4DSTEM.process.phase.iterative_multislice_ptychography import (
MultislicePtychographicReconstruction,
)
from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import (
OverlapMagneticTomographicReconstruction,
)
from py4DSTEM.process.phase.iterative_overlap_tomography import (
OverlapTomographicReconstruction,
)
from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction
from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction
from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction
from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction
from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction
from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction
from py4DSTEM.process.phase.iterative_simultaneous_ptychography import (
SimultaneousPtychographicReconstruction,
)
from py4DSTEM.process.phase.iterative_singleslice_ptychography import (
SingleslicePtychographicReconstruction,
)
from py4DSTEM.process.phase.parameter_optimize import (
OptimizationParameter,
PtychographyOptimizer,
)
from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction
from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction
from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer

# fmt: on
214 changes: 181 additions & 33 deletions py4DSTEM/process/phase/iterative_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,53 @@ def attach_datacube(self, datacube: DataCube):
self._datacube = datacube
return self

def reinitialize_parameters(self, device: str = None, verbose: bool = None):
"""
Reinitializes common parameters. This is useful when loading a previously-saved
reconstruction (which set device='cpu' and verbose=True for compatibility) ,
using different initialization parameters.

Parameters
----------
device: str, optional
If not None, imports and assigns appropriate device modules
verbose: bool, optional
If not None, sets the verbosity to verbose

Returns
--------
self: PhaseReconstruction
Self to enable chaining
"""

if device is not None:
if device == "cpu":
self._xp = np
self._asnumpy = np.asarray
from scipy.ndimage import gaussian_filter

self._gaussian_filter = gaussian_filter
from scipy.special import erf

self._erf = erf
elif device == "gpu":
self._xp = cp
self._asnumpy = cp.asnumpy
from cupyx.scipy.ndimage import gaussian_filter

self._gaussian_filter = gaussian_filter
from cupyx.scipy.special import erf

self._erf = erf
else:
raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
self._device = device

if verbose is not None:
self._verbose = verbose

return self

def set_save_defaults(
self,
save_datacube: bool = False,
Expand Down Expand Up @@ -278,7 +325,9 @@ def _extract_intensities_and_calibrations_from_datacube(
"""

# Copies intensities to device casting to float32
intensities = datacube.data
xp = self._xp

intensities = xp.asarray(datacube.data, dtype=xp.float32)
self._grid_scan_shape = intensities.shape[:2]

# Extracts calibrations
Expand All @@ -295,13 +344,14 @@ def _extract_intensities_and_calibrations_from_datacube(
if require_calibrations:
raise ValueError("Real-space calibrations must be given in 'A'")

warnings.warn(
(
"Iterative reconstruction will not be quantitative unless you specify "
"real-space calibrations in 'A'"
),
UserWarning,
)
if self._verbose:
warnings.warn(
(
"Iterative reconstruction will not be quantitative unless you specify "
"real-space calibrations in 'A'"
),
UserWarning,
)

self._scan_sampling = (1.0, 1.0)
self._scan_units = ("pixels",) * 2
Expand Down Expand Up @@ -359,13 +409,14 @@ def _extract_intensities_and_calibrations_from_datacube(
"Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'"
)

warnings.warn(
(
"Iterative reconstruction will not be quantitative unless you specify "
"appropriate reciprocal-space calibrations"
),
UserWarning,
)
if self._verbose:
warnings.warn(
(
"Iterative reconstruction will not be quantitative unless you specify "
"appropriate reciprocal-space calibrations"
),
UserWarning,
)

self._angular_sampling = (1.0, 1.0)
self._angular_units = ("pixels",) * 2
Expand Down Expand Up @@ -448,8 +499,6 @@ def _calculate_intensities_center_of_mass(
xp = self._xp
asnumpy = self._asnumpy

intensities = xp.asarray(intensities, dtype=xp.float32)

# for ptycho
if com_measured:
com_measured_x, com_measured_y = com_measured
Expand Down Expand Up @@ -484,22 +533,27 @@ def _calculate_intensities_center_of_mass(
)

if com_shifts is None:
com_measured_x_np = asnumpy(com_measured_x)
com_measured_y_np = asnumpy(com_measured_y)
finite_mask = np.isfinite(com_measured_x_np)

com_shifts = fit_origin(
(asnumpy(com_measured_x), asnumpy(com_measured_y)),
(com_measured_x_np, com_measured_y_np),
fitfunction=fit_function,
mask=finite_mask,
)

# Fit function to center of mass
com_fitted_x = xp.asarray(com_shifts[0], dtype=xp.float32)
com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32)

# fix CoM units
com_normalized_x = (com_measured_x - com_fitted_x) * self._reciprocal_sampling[
0
]
com_normalized_y = (com_measured_y - com_fitted_y) * self._reciprocal_sampling[
1
]
com_normalized_x = (
xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0]
)
com_normalized_y = (
xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1]
)

return (
com_measured_x,
Expand Down Expand Up @@ -1077,6 +1131,7 @@ def _normalize_diffraction_intensities(
diffraction_intensities,
com_fitted_x,
com_fitted_y,
crop_patterns,
):
"""
Fix diffraction intensities CoM, shift to origin, and take square root
Expand All @@ -1089,6 +1144,9 @@ def _normalize_diffraction_intensities(
Best fit horizontal center of mass gradient
com_fitted_y: (Rx,Ry) xp.ndarray
Best fit vertical center of mass gradient
crop_patterns: bool
if True, crop patterns to avoid wrap around of patterns
when centering

Returns
-------
Expand All @@ -1101,13 +1159,46 @@ def _normalize_diffraction_intensities(
xp = self._xp
mean_intensity = 0

amplitudes = xp.zeros_like(diffraction_intensities)
region_of_interest_shape = diffraction_intensities.shape[-2:]
diffraction_intensities = self._asnumpy(diffraction_intensities)
if crop_patterns:
crop_x = int(
np.minimum(
diffraction_intensities.shape[2] - com_fitted_x.max(),
com_fitted_x.min(),
)
)
crop_y = int(
np.minimum(
diffraction_intensities.shape[3] - com_fitted_y.max(),
com_fitted_y.min(),
)
)

crop_w = np.minimum(crop_y, crop_x)
region_of_interest_shape = (crop_w * 2, crop_w * 2)
amplitudes = np.zeros(
(
diffraction_intensities.shape[0],
diffraction_intensities.shape[1],
crop_w * 2,
crop_w * 2,
),
dtype=np.float32,
)

crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_)
crop_mask[:crop_w, :crop_w] = True
crop_mask[-crop_w:, :crop_w] = True
crop_mask[:crop_w:, -crop_w:] = True
crop_mask[-crop_w:, -crop_w:] = True
self._crop_mask = crop_mask

else:
region_of_interest_shape = diffraction_intensities.shape[-2:]
amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32)

com_fitted_x = self._asnumpy(com_fitted_x)
com_fitted_y = self._asnumpy(com_fitted_y)
diffraction_intensities = self._asnumpy(diffraction_intensities)
amplitudes = self._asnumpy(amplitudes)

for rx in range(diffraction_intensities.shape[0]):
for ry in range(diffraction_intensities.shape[1]):
Expand All @@ -1119,16 +1210,71 @@ def _normalize_diffraction_intensities(
device="cpu",
)

if crop_patterns:
intensities = intensities[crop_mask].reshape(
region_of_interest_shape
)

mean_intensity += np.sum(intensities)
amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0))

amplitudes = xp.asarray(amplitudes, dtype=xp.float32)

amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape)
amplitudes = xp.asarray(amplitudes)
mean_intensity /= amplitudes.shape[0]

return amplitudes, mean_intensity

def show_complex_CoM(
self,
com=None,
cbar=True,
scalebar=True,
pixelsize=None,
pixelunits=None,
**kwargs,
):
"""
Plot complex-valued CoM image

Parameters
----------

com = (CoM_x, CoM_y) tuple
If None is specified, uses (self.com_x, self.com_y) instead
cbar: bool, optional
if True, adds colorbar
scalebar: bool, optional
if True, adds scalebar to probe
pixelunits: str, optional
units for scalebar, default is A
pixelsize: float, optional
default is scan sampling
"""

if com is None:
com = (self.com_x, self.com_y)

if pixelsize is None:
pixelsize = self._scan_sampling[0]
if pixelunits is None:
pixelunits = r"$\AA$"

figsize = kwargs.pop("figsize", (6, 6))
fig, ax = plt.subplots(figsize=figsize)

complex_com = com[0] + 1j * com[1]

show_complex(
complex_com,
cbar=cbar,
figax=(fig, ax),
scalebar=scalebar,
pixelsize=pixelsize,
pixelunits=pixelunits,
ticks=False,
**kwargs,
)


class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints):
"""
Expand Down Expand Up @@ -1309,10 +1455,10 @@ def _get_constructor_args(cls, group):
"object_type": instance_md["object_type"],
"semiangle_cutoff": instance_md["semiangle_cutoff"],
"rolloff": instance_md["rolloff"],
"verbose": instance_md["verbose"],
"name": instance_md["name"],
"device": instance_md["device"],
"polar_parameters": polar_params,
"verbose": True, # for compatibility
"device": "cpu", # for compatibility
}

class_specific_kwargs = {}
Expand Down Expand Up @@ -2109,6 +2255,7 @@ def show_fourier_probe(
pixelunits = r"$\AA^{-1}$"

figsize = kwargs.pop("figsize", (6, 6))
chroma_boost = kwargs.pop("chroma_boost", 2)

fig, ax = plt.subplots(figsize=figsize)
show_complex(
Expand All @@ -2119,6 +2266,7 @@ def show_fourier_probe(
pixelsize=pixelsize,
pixelunits=pixelunits,
ticks=False,
chroma_boost=chroma_boost,
**kwargs,
)

Expand All @@ -2142,7 +2290,7 @@ def show_object_fft(self, obj=None, **kwargs):
vmax = kwargs.pop("vmax", 1)
power = kwargs.pop("power", 0.2)

pixelsize = 1 / (object_fft.shape[0] * self.sampling[0])
pixelsize = 1 / (object_fft.shape[1] * self.sampling[1])
show(
object_fft,
figsize=figsize,
Expand Down
Loading