diff --git a/README.rst b/README.rst index 59531ac4..f54dffe6 100644 --- a/README.rst +++ b/README.rst @@ -1,25 +1,6 @@ SigPy ===== -.. image:: https://img.shields.io/badge/License-BSD%203--Clause-blue.svg - :target: https://opensource.org/licenses/BSD-3-Clause - -.. image:: https://travis-ci.com/mikgroup/sigpy.svg?branch=master - :target: https://travis-ci.com/mikgroup/sigpy - -.. image:: https://readthedocs.org/projects/sigpy/badge/?version=latest - :target: https://sigpy.readthedocs.io/en/latest/?badge=latest - :alt: Documentation Status - -.. image:: https://codecov.io/gh/mikgroup/sigpy/branch/master/graph/badge.svg - :target: https://codecov.io/gh/mikgroup/sigpy - -.. image:: https://zenodo.org/badge/139635485.svg - :target: https://zenodo.org/badge/latestdoi/139635485 - - -`Source Code `_ | `Documentation `_ | `MRI Recon Tutorial `_ | `MRI Pulse Design Tutorial `_ - SigPy is a package for signal processing, with emphasis on iterative methods. It is built to operate directly on NumPy arrays on CPU and CuPy arrays on GPU. SigPy also provides several domain-specific submodules: ``sigpy.plot`` for multi-dimensional array plotting, ``sigpy.mri`` for MRI reconstruction, and ``sigpy.mri.rf`` for MRI pulse design. Installation @@ -48,7 +29,7 @@ SigPy can also be installed through ``pip``:: # (optional for plot support) pip install matplotlib # (optional for CUDA support) pip install cupy # (optional for MPI support) pip install mpi4py - + Installation for Developers *************************** @@ -56,7 +37,7 @@ If you want to contribute to the SigPy source code, we recommend you install it cd /path/to/sigpy pip install -e . - + To run tests and contribute, we recommend installing the following packages:: pip install coverage ruff sphinx sphinx_rtd_theme black isort @@ -101,4 +82,3 @@ Want to do machine learning without giving up signal processing? SigPy has conve x_torch = sigpy.to_pytorch(x) A_torch = sigpy.to_pytorch_function(A) - diff --git a/conda.recipe/install.sh b/conda.recipe/install.sh new file mode 100644 index 00000000..5bea30b3 --- /dev/null +++ b/conda.recipe/install.sh @@ -0,0 +1,27 @@ +# to install sigpy from scratch: + +git clone https://github.com/ZhengguoTan/sigpy.git sigpy_github + +cd sigpy_github + +conda create -n sigpy_github python=3.10 +conda activate sigpy_github + +conda install -c anaconda pip + +python -m pip install torch torchvision torchaudio + +python -m pip install tqdm +python -m pip install numba +python -m pip install scipy +python -m pip install pywavelets +python -m pip install h5py +python -m pip install matplotlib + +# please - +# (1) check your cuda version here +# (2) log into gpu machine when you are in hpc +# https://docs.cupy.dev/en/stable/install.html +pip install cupy-cuda11x + +pip install -e . \ No newline at end of file diff --git a/sigpy/__init__.py b/sigpy/__init__.py index a053a4fe..9680b916 100644 --- a/sigpy/__init__.py +++ b/sigpy/__init__.py @@ -8,23 +8,13 @@ that can be used in conjuction with Alg to form an App. """ -from sigpy import ( - alg, - app, - backend, - block, - config, - conv, - fourier, - interp, - linop, - prox, - pytorch, - sim, - thresh, - util, - wavelet, -) + +from .version import __version__ # noqa +from sigpy import alg, app, config, linop, prox, nlop + +from sigpy import (backend, block, conv, interp, + fourier, pytorch, sim, thresh, + util, wavelet) from sigpy.backend import * # noqa from sigpy.block import * # noqa from sigpy.conv import * # noqa @@ -36,9 +26,7 @@ from sigpy.util import * # noqa from sigpy.wavelet import * # noqa -from .version import __version__ # noqa - -__all__ = ["alg", "app", "config", "linop", "prox"] +__all__ = ['alg', 'app', 'config', 'linop', 'prox', 'nlop'] __all__.extend(backend.__all__) __all__.extend(block.__all__) __all__.extend(conv.__all__) diff --git a/sigpy/alg.py b/sigpy/alg.py index fce675ad..228b86b1 100644 --- a/sigpy/alg.py +++ b/sigpy/alg.py @@ -228,12 +228,13 @@ class ConjugateGradient(Alg): """ - def __init__(self, A, b, x, P=None, max_iter=100, tol=0): + def __init__(self, A, b, x, P=None, max_iter=100, tol=0, verbose=False): self.A = A self.b = b self.P = P self.x = x self.tol = tol + self.verbose = verbose self.device = backend.get_device(x) with self.device: xp = self.device.xp @@ -278,7 +279,12 @@ def _update(self): util.xpay(self.p, beta, z) self.rzold = rznew - self.resid = self.rzold.item() ** 0.5 + self.resid = self.rzold.item()**0.5 + + if self.verbose: + print(" cg iter: " + "%2d" % (self.iter) + + "; resid: " + "%13.6f" % (self.resid) + + "; norm: " + "%13.6f" % (xp.linalg.norm(self.x.flatten()))) def _done(self): return ( @@ -840,6 +846,106 @@ def _done(self): return self.iter >= self.max_iter or self.residual <= self.tol +class IRGNM(Alg): + r"""Iteratively Regularized Gauss-Newton Method (IRGNM) + + Args: + A (nlop): Non-linear forward model. + y (array): Observation. + x (array): Variable. + x0 (array): bias for L2 regularization. + max_iter (int): maximum number of iterations. + alpha (int): regularization strength. + redu (float): reduction factor along iterations. + inner_iter (int): maximum number of inner iterations. + inner_tol (float): tolerance for the inner iterations. + + References: + Bauer F., Kannengiesser S. (2007). + An alternative approach to the image reconstruction + for parallel data acquisition in MRI. + Math. Meth. Appl. Sci. 30, 1437-1451. + + Uecker M., Hohage T., Block K. T., Frahm J. (2008) + Image reconstruction by regularized nonlinear inversion + -- joint estimation of coil sensitivities and image content. + Magn. Reson. Med. 60, 674-682. + + Tan Z., Roeloffs V., Voit D., Joseph A. A., Untenberger M., + Merboldt K. D., Frahm J. (2016). + Model-based reconstruction for real-time phase-contrast flow MRI: + Improved spatiotemporal accuracy. + Magn. Reson. Med. 77, 1082-1093. + + """ + def __init__(self, A, y, x, x0=None, + max_iter=6, alpha=1, redu=2, + inner_iter=100, inner_tol=0.01, verbose=True): + self.A = A + self.y = y + self.device = backend.get_device(y) + + # outer iteration + self.max_iter = max_iter + self.alpha = alpha + self.redu = redu + + # inner iteration + self.inner_iter = inner_iter + self.inner_tol = inner_tol + + self.verbose = verbose + + xp = self.device.xp + with self.device: + self.x = sp.to_device(x, device=self.device) + + self.x0 = xp.zeros(x.shape, dtype=y.dtype) + if x0 is not None: + self.x0 = sp.to_device(x0, device=self.device) + + super().__init__(max_iter) + + def _update(self): + xp = self.device.xp + + with self.device: + dx = xp.zeros_like(self.x) + + r = self.y - self.A(self.x) + + resid = xp.linalg.norm(r).item() + + if self.verbose: + print("gn iter: " + "%2d" % (self.iter) + + "; alpha: " + "%.6f" % (self.alpha) + + "; resid: " + "%.6f" % (resid)) + + p = self.A.adjoint(self.x, r) + p += self.alpha * (self.x0 - self.x) + + # update dx + def AHA(x): + return self.A.adjoint(self.x, self.A.derivative(self.x, x)) \ + + self.alpha * x + + inner_tol = self.inner_tol * xp.linalg.norm(p).item() + + inner_alg = ConjugateGradient(AHA, p, dx, + max_iter=self.inner_iter, + tol=inner_tol) + while not inner_alg.done(): + inner_alg.update() + + # update x + self.x += 1. * dx + self.alpha /= self.redu + + def _done(self): + over_iter = self.iter >= self.max_iter + return over_iter + + class GerchbergSaxton(Alg): """Gerchberg-Saxton method, also called the variable exchange method. Iterative method for recovery of a signal from the amplitude of linear diff --git a/sigpy/app.py b/sigpy/app.py index 79c54d45..d7069719 100644 --- a/sigpy/app.py +++ b/sigpy/app.py @@ -8,13 +8,10 @@ from tqdm.auto import tqdm from sigpy import backend, linop, prox, util -from sigpy.alg import ( - ADMM, - ConjugateGradient, - GradientMethod, - PowerMethod, - PrimalDualHybridGradient, -) +from sigpy.alg import (PowerMethod, GradientMethod, ADMM, IRGNM, + ConjugateGradient, PrimalDualHybridGradient) + +import numpy as np class App(object): @@ -185,34 +182,19 @@ class LinearLeastSquares(App): """ - def __init__( - self, - A, - y, - x=None, - proxg=None, - lamda=0, - G=None, - g=None, - z=None, - solver=None, - max_iter=100, - P=None, - alpha=None, - max_power_iter=30, - accelerate=True, - tau=None, - sigma=None, - rho=1, - max_cg_iter=10, - tol=0, - save_objective_values=False, - show_pbar=True, - leave_pbar=True, - ): + def __init__(self, A, y, x=None, proxg=None, + lamda=0, G=None, g=None, z=None, + solver=None, max_iter=100, scale=1, + P=None, alpha=None, max_power_iter=30, accelerate=True, + tau=None, sigma=None, + rho=1, max_cg_iter=10, tol=0, + save_objective_values=False, + show_pbar=True, leave_pbar=True, + verbose=False): self.A = A self.y = y self.x = x + self.scale = scale self.proxg = proxg self.lamda = lamda self.G = G @@ -233,12 +215,32 @@ def __init__( self.show_pbar = show_pbar self.leave_pbar = leave_pbar + self.iter_step = 0 + self.verbose = verbose + self.y_device = backend.get_device(y) if self.x is None: with self.y_device: self.x = self.y_device.xp.zeros(A.ishape, dtype=y.dtype) self.x_device = backend.get_device(self.x) + + # make sure the arrays are on the same device. + # If one of them is on GPU, send the another one also to GPU. + if self.y_device != self.x_device: + if self.y_device == backend.cpu_device: + y = backend.to_device(y, device=self.x_device) + self.y_device = backend.get_device(y) + elif self.x_device == backend.cpu_device: + self.x = backend.to_device(self.x, + device=backend.get_device(y)) + self.x_device = backend.get_device(self.x) + + assert self.y_device == self.x_device + + if self.z is not None: + self.z = backend.to_device(self.z, device=self.y_device) + self._get_alg() if self.save_objective_values: self.objective_values = [self.objective()] @@ -262,7 +264,7 @@ def _summarize(self): ) def _output(self): - return self.x + return self.x * self.scale def _get_alg(self): if self.solver is None: @@ -432,6 +434,9 @@ def _get_ADMM(self): \frac{\lambda}{2} \| x - z \|_2^2 + g(v) """ + ABSTOL = 1E-4 + RELTOL = 1E-3 + xp = self.x_device.xp with self.x_device: if self.G is None: @@ -461,14 +466,16 @@ def minL_x(): AHA += self.rho * self.G.H * self.G - App( - ConjugateGradient( - AHA, AHy, self.x, P=self.P, max_iter=self.max_cg_iter - ), - show_pbar=False, - ).run() + App(ConjugateGradient(AHA, AHy, self.x, P=self.P, + max_iter=self.max_cg_iter, + verbose=self.verbose), + show_pbar=False).run() def minL_v(): + + self.iter_step += 1 + v_old = v.copy() + if self.G is None: backend.copyto(v, self.x + u) else: @@ -477,6 +484,28 @@ def minL_v(): if self.proxg is not None: backend.copyto(v, self.proxg(1 / self.rho, v)) + if self.verbose: + with self.x_device: + Gx = self.G(self.x) + + r_norm = xp.linalg.norm(Gx - v).item() + s_norm = xp.linalg.norm(-self.rho * (v - v_old)).item() + + r_scaling = max(xp.linalg.norm(Gx).item(), + xp.linalg.norm(v).item()) + s_scaling = self.rho * xp.linalg.norm(u).item() + + eps_pri = ABSTOL * (np.prod(v.shape)**0.5) \ + + RELTOL * r_scaling + eps_dual = ABSTOL * (np.prod(v.shape)**0.5) \ + + RELTOL * s_scaling + + print('admm iter: ' + "%2d" % (self.iter_step) + + ', r norm: ' + "%10.4f" % (r_norm) + + ', eps pri: ' + "%10.4f" % (eps_pri) + + ', s norm: ' + "%10.4f" % (s_norm) + + ', eps dual: ' + "%10.4f" % (eps_dual)) + I_v = linop.Identity(v.shape) if self.G is None: I_x = linop.Identity(self.x.shape) @@ -523,6 +552,143 @@ def objective(self): return obj +class NonLinearLeastSquares(App): + r"""Non-linear least squares application. + + Solves for the following problem, with optional regularizations: + + .. math: + \min_x \frac{1}{2} \| A x - y \|_2^2 + \alpha G x + + A is a non-linear operator. + \alpha is the regularization strength. + G is a regulariztion term on x, + e.g. G x := \| T x \|_1, with T being total variation. + + In this case, the non-linear problem can be solved via + Alternating Direction Method of Multipliers (ADMM): + + (1) x^{(k+1)} := argmin_x \| A x -y \|_2^2 + + \rho/2 \| T x - z^{(k)} + + u^{(k)} \|_2^2 + (2) z^{(k+1)} := \Tau_{\alpha/\rho} (T x^{(k+1)} + u^{(k)}) + (3) u^{(k+1)} := u^{(k)} + T x^{(k+1)} - z^{(k+1)} + + Args: + A (Nlop): Forward non-linear operator. + y (array): Observation. e.g. k-space data. + x (array): Solution. + x0 (array): bias for L2 regularization. + max_iter (int): Maximum number of iterations. + lamda (float): regularization strength. + redu (float): reduction factor along iterations. + gn_iter (int): number of Gauss-Newton iterations. + inner_iter (int): number of inner iterations. + inner_tol (float): tolerance for the inner iterations. + G (None or Linop): Regularization linear operator. + proxg (None or prox): Proximal operator. + + Author: + Zhengguo Tan + """ + def __init__(self, A, y, x=None, x0=None, + max_iter=6, lamda=1, redu=2, + gn_iter=4, inner_iter=100, inner_tol=0.01, + G=None, proxg=None, + show_pbar=True, leave_pbar=True, verbose=False): + self.A = A + self.y = y + + self.device = backend.get_device(y) + + self.x = self._array_to_device(x, A.ishape, y.dtype) + self.x0 = self._array_to_device(x0, A.ishape, y.dtype) + + self.max_iter = max_iter + self.lamda = lamda + self.redu = redu + self.gn_iter = gn_iter + self.inner_iter = inner_iter + self.inner_tol = inner_tol + + self.G = G + self.proxg = proxg + + self.show_pbar = show_pbar + self.leave_pbar = leave_pbar + self.verbose = verbose + + self._get_alg() + + super().__init__(self.alg, show_pbar=show_pbar, + leave_pbar=leave_pbar) + + def _output(self): + return self.x + + def _array_to_device(self, arr, shape, dtype): + xp = self.device.xp + + with self.device: + if arr is None: + arr_on_device = xp.zeros(shape, dtype=dtype) + else: + arr_on_device = backend.to_device(arr, device=self.device) + + return arr_on_device + + def _get_alg(self): + if self.proxg is None: + # L2 regularization + # TODO: incorporate G + self.alg = IRGNM(self.A, self.y, self.x, + x0=self.x0, max_iter=self.gn_iter, + alpha=self.lamda, redu=self.redu, + inner_iter=self.inner_iter, + inner_tol=self.inner_tol, + verbose=self.verbose) + + else: + xp = self.device.xp + with self.device: + if self.G is None: + v = self.x.copy() + else: + v = self.G(self.x) + + u = xp.zeros_like(v) + + def _minL_x(): + if self.G is None: + x0 = v - u + else: + x0 = self.G.H(v - u) + + App(IRGNM(self.A, self.y, self.x, x0=x0, + max_iter=self.gn_iter, + alpha=self.lamda, redu=self.redu, + inner_iter=self.inner_iter, + inner_tol=self.inner_tol, + verbose=self.verbose), show_pbar=False).run() + + def _minL_v(): + if self.G is None: + backend.copyto(v, self.x + u) + else: + backend.copyto(v, self.G(self.x) + u) + + backend.copyto(v, self.proxg(1 / self.lamda, v)) + + I_v = linop.Identity(v.shape) + if self.G is None: + G = linop.Identity(self.x.shape) + else: + G = self.G + + self.alg = ADMM(_minL_x, _minL_v, self.x, v, u, + G, -I_v, 0, max_iter=self.max_iter) + + class L2ConstrainedMinimization(App): r"""L2 contrained minimization application. diff --git a/sigpy/config.py b/sigpy/config.py index 3a3858ff..8a8c393a 100644 --- a/sigpy/config.py +++ b/sigpy/config.py @@ -6,8 +6,10 @@ """ import warnings from importlib import util +import torch -cupy_enabled = util.find_spec("cupy") is not None +cuda_avail = torch.cuda.is_available() +cupy_enabled = (util.find_spec("cupy") is not None) and cuda_avail if cupy_enabled: try: import cupy # noqa diff --git a/sigpy/coord.py b/sigpy/coord.py new file mode 100644 index 00000000..b48caec4 --- /dev/null +++ b/sigpy/coord.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +"""Methods for Coordinate Conversion + +Reference: + https://math.libretexts.org/Bookshelves/Calculus/Book%3A_Calculus_(OpenStax)/12%3A_Vectors_in_Space/12.07%3A_Cylindrical_and_Spherical_Coordinates#:~:text=To%20convert%20a%20point%20from,y2%2Bz2). + +Author: Zhengguo Tan +""" + +import numpy as np + + +def cartes_to_spheri(x, y, z): + + r = (x**2 + y**2 + z**2)**0.5 + theta = np.arctan2(y, x) + phi = np.arccos(z / r) + + return r, theta, phi + + +def spheri_to_cartes(r, theta, phi): + + x = r * np.sin(phi) * np.cos(theta) + y = r * np.sin(phi) * np.sin(theta) + z = r * np.cos(phi) + + return x, y, z diff --git a/sigpy/extract_dicom.py b/sigpy/extract_dicom.py new file mode 100644 index 00000000..76f0843d --- /dev/null +++ b/sigpy/extract_dicom.py @@ -0,0 +1,132 @@ +""" +Methods for DICOM loading, extraction and saving. + +Author: Zhengguo Tan +""" + +# %% +import numpy as np +import os +import pydicom + +import os.path +import sys +sys.path.insert(0, os.path.join(os.environ['BART_PATH'], 'python')) +import cfl + +from optparse import OptionParser + +# __all__ = ['extract_ep2d_diff'] + + +def save_dicom_img(dcm_file, out_file, file_type='cfl'): + """Save dicom file as image. + + Args: + dcm_file (string): dicom file name (including the directory). + file_type (string): which format to save ('cfl' or 'npy'). + """ + ds = pydicom.filereader.dcmread(dcm_file) + + img = ds.pixel_array + + print('> write to ' + out_file) + + if file_type == 'cfl': + cfl.writecfl(out_file, img) + elif file_type == 'npy': + np.save(out_file, img) + + return None + + +def ep2d_diff(dcm_dir, start=1, incre=1, + save=False, save_dir=os.getcwd(), save_file='test', + save_type='cfl'): + """Extract b values and g vectors from dicom headers. + + Args: + dcm_dir (string): directories where dicom files are stored. + start (int): starting number. + incre (int): incremental number (step). + + Returns: + b (array): b-value array. + g (array): g-vectory array. + """ + + # initialization + b = np.array([], dtype=np.float).reshape(0, 1) + g = np.array([], dtype=np.float).reshape(0, 3) + + dcm_files = sorted(os.listdir(dcm_dir)) + stop = len(dcm_files) + + for index in np.arange(start=start, stop=stop+1, step=incre): + dcm_file = dcm_files[index-1] + # print(dcm_file) + dcm_dir_file = os.path.join(dcm_dir, dcm_file) + + assert dcm_file.find(str(index).zfill(4)) >= 0 + assert os.path.isfile(dcm_dir_file) == True + + ds = pydicom.filereader.dcmread(dcm_dir_file) + + if save: + img = ds.pixel_array + if index == start: + out = np.reshape(img, list(img.shape) + [1]) + print(out.shape) + else: + out = np.concatenate((out, + np.reshape(img, list(img.shape) + [1])), + axis=2) + # save_dicom_img(dcm_dir_file, dcm_dir_file) + + # b values (b) + element = ds[0x0019, 0x100C] + assert element.name.find('B_value') >= 0 + bval = float(element.value) + b = np.vstack([b, bval]) + + # diffusion gradient directions (g) + # for b-value equals 0, the dicom tag for g does not exist! + if bval == 0: + g = np.vstack([g, [0., 0., 0.]]) + else: + element = ds[0x0019, 0x100E] + assert element.name.find('DiffusionGradientDirection') >= 0 + g = np.vstack([g, element.value]) + + if save: + save_name = save_dir + '/' + save_file + if save_type == 'cfl': + cfl.writecfl(save_name, out) + elif save_type == 'npy': + np.save(save_name, out) + + return b, g + + +# %% +if __name__ == "__main__": + usage = "%prog [options] " + + parser = OptionParser(description="extract b-values (b) diffusion-gradient-directions (g) from dicom files", usage=usage) + + parser.add_option("--start", dest="file_start", + help="dicom file starting number", default=1) + parser.add_option("--incre", dest="file_incre", + help="dicom file incremental number", default=1) + + (options, args) = parser.parse_args() + + file_start = int(options.file_start) + file_incre = int(options.file_incre) + + dir = str(args[0]) + + b, g = ep2d_diff(dir, start=file_start, incre=file_incre) + + np.save(str(args[1]), b) + np.save(str(args[2]), g) diff --git a/sigpy/linop.py b/sigpy/linop.py index 98fa23a4..2f3ee47d 100644 --- a/sigpy/linop.py +++ b/sigpy/linop.py @@ -6,8 +6,7 @@ """ import numpy as np -from sigpy import backend, block, conv, fourier, interp, util, wavelet - +from sigpy import backend, block, fourier, util, interp, conv, wavelet, nlop def _check_shape_positive(shape): if not all(s > 0 for s in shape): @@ -153,6 +152,8 @@ def __call__(self, input): def __mul__(self, input): if isinstance(input, Linop): return Compose([self, input]) + elif isinstance(input, nlop.Nlop): + return nlop.Compose([self, input]) elif np.isscalar(input): M = Multiply(self.ishape, input) return Compose([self, M]) @@ -1378,11 +1379,20 @@ class Sum(Linop): Args: ishape (tuple of ints): Input shape. axes (tuple of ints): Axes to sum over. + keepdims (boolean): Keep the summed dims (axes). """ - def __init__(self, ishape, axes): + def __init__(self, ishape, axes, keepdims=False): self.axes = tuple(i % len(ishape) for i in axes) - oshape = [ishape[i] for i in range(len(ishape)) if i not in self.axes] + + self.keepdims = keepdims + + if self.keepdims is False: + oshape = [ishape[i] for i in range(len(ishape)) if i not in self.axes] + else: + oshape = [ishape[i] for i in range(len(ishape))] + for i in self.axes: + oshape[i] = 1 super().__init__(oshape, ishape) @@ -1390,10 +1400,10 @@ def _apply(self, input): device = backend.get_device(input) xp = device.xp with device: - return xp.sum(input, axis=self.axes) + return xp.sum(input, axis=self.axes, keepdims=self.keepdims) def _adjoint_linop(self): - return Tile(self.ishape, self.axes) + return Tile(self.ishape, self.axes, keepdims=self.keepdims) class Tile(Linop): @@ -1402,12 +1412,22 @@ class Tile(Linop): Args: oshape (tuple of ints): Output shape. axes (tuple of ints): Axes to tile. + keepdims (boolean): Keep the summed dims (axes). """ - def __init__(self, oshape, axes): + def __init__(self, oshape, axes, keepdims=False): self.axes = tuple(a % len(oshape) for a in axes) - ishape = [oshape[d] for d in range(len(oshape)) if d not in self.axes] + + self.keepdims = keepdims + + if self.keepdims is False: + ishape = [oshape[d] for d in range(len(oshape)) if d not in self.axes] + else: + ishape = [oshape[d] for d in range(len(oshape))] + for i in self.axes: + ishape[i] = 1 + self.expanded_ishape = [] self.reps = [] for d in range(len(oshape)): @@ -1437,20 +1457,22 @@ class ArrayToBlocks(Linop): ishape (array): input array of shape [..., N_1, ..., N_D] blk_shape (tuple): block shape of length D, with D <= 4. blk_strides (tuple): block strides of length D. + mean (boolean): take the mean of repeated position. See Also: :func:`sigpy.block.array_to_blocks` """ - def __init__(self, ishape, blk_shape, blk_strides): + def __init__(self, ishape, blk_shape, blk_strides, mean=True): self.blk_shape = blk_shape self.blk_strides = blk_strides D = len(blk_shape) - num_blks = [ - (i - b + s) // s - for i, b, s in zip(ishape[-D:], blk_shape, blk_strides) - ] + num_blks = [(i - b + s) // s for i, b, + s in zip(ishape[-D:], blk_shape, blk_strides)] + self.num_blks = num_blks + self.mean = mean + oshape = list(ishape[:-D]) + num_blks + list(blk_shape) super().__init__(oshape, ishape) @@ -1463,7 +1485,8 @@ def _apply(self, input): ) def _adjoint_linop(self): - return BlocksToArray(self.ishape, self.blk_shape, self.blk_strides) + return BlocksToArray(self.ishape, self.blk_shape, self.blk_strides, + mean=self.mean) def _normal_linop(self): return Identity(self.ishape) @@ -1476,13 +1499,14 @@ class BlocksToArray(Linop): oshape (tuple): output shape. blk_shape (tuple): block shape of length D. blk_strides (tuple): block strides of length D. + mean (boolean): take the mean of repeated position. Returns: array: array of shape oshape. """ - def __init__(self, oshape, blk_shape, blk_strides): + def __init__(self, oshape, blk_shape, blk_strides, mean=True): self.blk_shape = blk_shape self.blk_strides = blk_strides D = len(blk_shape) @@ -1492,17 +1516,40 @@ def __init__(self, oshape, blk_shape, blk_strides): ] ishape = list(oshape[:-D]) + num_blks + list(blk_shape) + self.mean = mean + + # Inspired by: + # https://pytorch.org/docs/stable/generated/torch.nn.Fold.html#torch.nn.Fold + # Switch on this option such that: + # BlockToArrays(ArrayToBlocks(input, ...), ...) = input + # i.e. average duplicated values + divisor = np.ones(oshape) + if self.mean: + divisor = block.array_to_blocks(divisor, blk_shape, blk_strides) + divisor = block.blocks_to_array(divisor, oshape, + blk_shape, blk_strides) + + self.divisor = divisor + super().__init__(oshape, ishape) def _apply(self, input): device = backend.get_device(input) + xp = device.xp + + divisor = backend.to_device(self.divisor, device=device) + with device: - return block.blocks_to_array( - input, self.oshape, self.blk_shape, self.blk_strides - ) + output = block.blocks_to_array( + input, self.oshape, self.blk_shape, self.blk_strides) + + output = xp.where(divisor > 0, xp.divide(output, divisor), 0) + + return output def _adjoint_linop(self): - return ArrayToBlocks(self.oshape, self.blk_shape, self.blk_strides) + return ArrayToBlocks(self.oshape, self.blk_shape, self.blk_strides, + mean=self.mean) def _normal_linop(self): return Identity(self.ishape) @@ -1639,6 +1686,190 @@ def _adjoint_linop(self): ) +class HDNUFFT(NUFFT): + """High-dimensional (HD) NUFFT. + + Args: + ishape (tuple of int): Input shape. + coord (array): Coordinates, with values [-ishape / 2, ishape / 2] + oversamp (float): Oversampling factor. + width (float): Kernel width. + toeplitz (bool): Use toeplitz PSF to evaluate normal operator. + nr_hd (int): number of higher dimensions (HD). + HD starts from left most to right in ishape and coord. + Default 1. + When nr_hd == 0, HDNUFFT is idential to NUFFT. + + Author: + Zhengguo Tan + """ + def __init__(self, ishape, coord, + oversamp=1.25, width=4, toeplitz=False, nr_hd=1, + use_dcf=False): + self.coord = coord + self.oversamp = oversamp + self.width = width + self.toeplitz = toeplitz + + if len(ishape) <= 3 and len(coord.shape) <= 3: + nr_hd = 0 + + self.nr_hd = nr_hd + self.use_dcf = use_dcf + + self._check_higher_dim(ishape, coord) + + ndim = coord.shape[-1] + cshape = coord.shape + excl_hd = len(cshape) - nr_hd + + oshape = list(ishape[:-ndim]) + list(cshape[-excl_hd:-1]) + + super(NUFFT, self).__init__(oshape, ishape) + + def _check_higher_dim(self, ishape, coord): + cshape = coord.shape + + for n in range(self.nr_hd): + if ishape[n] != cshape[n]: + raise ValueError( + 'shape mismatch for {s} between input {ishape} and coord {cshape}'.format( + s=self, ishape=ishape, cshape=coord.shape)) + + def _apply(self, input): + # call the parent method + # when this is not high-dimensional + if self.nr_hd == 0: + return super(HDNUFFT, self)._apply(input) + + sz_hd = np.prod(self.ishape[:self.nr_hd]) + output = np.zeros([sz_hd] + list(self.oshape[self.nr_hd:]), + dtype=complex) + + device = backend.get_device(input) + with device: + xp = device.xp + coord = backend.to_device(self.coord, device) + output = backend.to_device(output, device=device) + + coord = xp.reshape(coord, [sz_hd] + list(coord.shape[self.nr_hd:])) + input = xp.reshape(input, [sz_hd] + list(input.shape[self.nr_hd:])) + + ld_ishape = self.ishape[self.nr_hd:] + + for nhd in range(sz_hd): + ld_coord = coord[nhd, ...] + ld_input = input[nhd, ...] + + F = NUFFT(ld_ishape, ld_coord, + oversamp=self.oversamp, width=self.width, + toeplitz=self.toeplitz) + + output[nhd, ...] = F(ld_input) + + output = xp.reshape(output, self.oshape) + return output + + def _adjoint_linop(self): + return HDNUFFTAdjoint(self.ishape, self.coord, + oversamp=self.oversamp, width=self.width, + toeplitz=self.toeplitz, nr_hd=self.nr_hd, + use_dcf=self.use_dcf) + + def _normal_linop(self): + if self.toeplitz is False: + return self.H * self + + +class HDNUFFTAdjoint(NUFFTAdjoint): + """high-dimensional NUFFT adjoint linear operator. + + Args: + oshape (tuple of int): Output shape. + coord (array): Coordinates, with values [-ishape / 2, ishape / 2] + oversamp (float): Oversampling factor. + width (float): Kernel width. + toeplitz (bool): Use toeplitz PSF to evaluate normal operator. + nr_hd (int): number of higher dimensions (HD). + HD starts from left most to right in ishape and coord. Default 1. + When nr_hd == 0, HDNUFFT is idential to NUFFT. + + Author: + Zhengguo Tan + """ + def __init__(self, oshape, coord, + oversamp=1.25, width=4, toeplitz=False, nr_hd=1, + use_dcf=False): + self.coord = coord + self.oversamp = oversamp + self.width = width + self.toeplitz = toeplitz + + if len(oshape) <= 3 and len(coord.shape) <= 3: + nr_hd = 0 + + self.nr_hd = nr_hd + self.use_dcf = use_dcf + + ndim = coord.shape[-1] + cshape = coord.shape + excl_hd = len(cshape) - nr_hd + + ishape = list(oshape[:-ndim]) + list(cshape[-excl_hd:-1]) + + super(NUFFTAdjoint, self).__init__(oshape, ishape) + + def _apply(self, input): + # call the parent method + # when this is not high-dimensional + if self.nr_hd == 0: + return super(HDNUFFTAdjoint, self)._apply(input) + + sz_hd = np.prod(self.ishape[:self.nr_hd]) + output = np.zeros([sz_hd] + list(self.oshape[self.nr_hd:]), + dtype=complex) + + device = backend.get_device(input) + with device: + xp = device.xp + coord = backend.to_device(self.coord, device) + output = backend.to_device(output, device) + + coord = xp.reshape(coord, [sz_hd] + list(coord.shape[self.nr_hd:])) + input = xp.reshape(input, [sz_hd] + list(input.shape[self.nr_hd:])) + + if self.use_dcf is True: + from sigpy.mri import dcf + N_dim = coord.shape[-1] + dcf = dcf.pipe_menon_dcf(coord, + img_shape=self.oshape[-N_dim:], + device=device) + dcf_c = dcf.astype(input.dtype) + else: + dcf_c = xp.ones_like(input) + + ld_oshape = self.oshape[self.nr_hd:] + + for nhd in range(sz_hd): + ld_coord = coord[nhd, ...] + ld_input = input[nhd, ...] + ld_dcf_c = dcf_c[nhd, ...] + + F = NUFFTAdjoint(ld_oshape, ld_coord, + oversamp=self.oversamp, + width=self.width) + + output[nhd, ...] = F(ld_input * ld_dcf_c) + + output = xp.reshape(output, self.oshape) + return output + + def _adjoint_linop(self): + return HDNUFFT(self.oshape, self.coord, + oversamp=self.oversamp, width=self.width, + toeplitz=self.toeplitz, nr_hd=self.nr_hd) + + class ConvolveData(Linop): r"""Convolution operator for data arrays. @@ -1886,7 +2117,9 @@ def __init__(self, ishape, idx): super().__init__(oshape, ishape) def _apply(self, input): - return input[self.idx] + device = backend.get_device(input) + with device: + return input[self.idx] def _adjoint_linop(self): return Embed(self.ishape, self.idx) @@ -1910,9 +2143,91 @@ def __init__(self, oshape, idx): super().__init__(oshape, ishape) def _apply(self, input): - output = np.zeros(self.oshape, dtype=input.dtype) - output[self.idx] = input + device = backend.get_device(input) + xp = device.xp + + with device: + output = xp.zeros(self.oshape, dtype=input.dtype) + output[self.idx] = input + return output def _adjoint_linop(self): return Slice(self.oshape, self.idx) + + +class Sobolev(Linop): + """Sobolev weight + + Given input in k-space, + returns output as: + + output = ifft(W(input)) + + W = _get_Sobolev_weight(W_shape) + = ( 1 + a*|k|^2 )^(-b) + + where k is the meshgrid of the ishape, normalized to [-0.5, 0.5]. + + Args: + ishape (tuple of ints): input shape + a (int): a in Sobolev weight [default: 440] + b (int): b in Sobolev weight [default: 16] + """ + def __init__(self, ishape, a=440, b=16): + wshape = ishape[-2:] + self.W = self._get_Sobolev_weight(wshape, a, b) + self.S = Multiply(ishape, self.W) + self.F = IFFT(ishape, axes=range(-2, 0)) + super().__init__(ishape, ishape) + + def _get_Sobolev_weight(self, weight_shape, a=440, b=16): + _check_shape_positive(weight_shape) + + W = np.zeros(shape=weight_shape, dtype=complex) + + NY = weight_shape[1] + NX = weight_shape[0] + + for y in range(0, NY): + for x in range(0, NX): + dist = pow(x/NX - 0.5, 2.) + pow(y/NY - 0.5, 2.) + W[x, y] = 1./pow(1 + a*dist, b) + 0j + + return W + + def _apply(self, input): + device = backend.get_device(input) + with device: + return self.F * self.S * input + + def _adjoint_linop(self): + return (self.F * self.S).H + + +class RealValueConstraint(Linop): + """real-value constraint linear operator. + + Returns real-value input directly. + + Args: + shape (tuple of ints): Input shape + + """ + + def __init__(self, shape): + super().__init__(shape, shape) + + def _apply(self, input): + device = backend.get_device(input) + xp = device.xp + with device: + output = xp.real(input).astype(complex) + + return output + + def _adjoint_linop(self): + return self + + def _normal_linop(self): + return self \ No newline at end of file diff --git a/sigpy/mri/app.py b/sigpy/mri/app.py index 88b9acbd..752465a9 100644 --- a/sigpy/mri/app.py +++ b/sigpy/mri/app.py @@ -2,23 +2,26 @@ """MRI applications. """ import numpy as np - import sigpy as sp -from sigpy.mri import linop -__all__ = [ - "SenseRecon", - "L1WaveletRecon", - "TotalVariationRecon", - "JsenseRecon", - "EspiritCalib", -] +from sigpy.mri import linop, nlop +from sigpy.mri.dims import * + +__all__ = ['SenseRecon', 'L1WaveletRecon', 'TotalVariationRecon', + 'JsenseRecon', 'EspiritCalib', 'HighDimensionalRecon', + 'ModelDiffRecon'] -def _estimate_weights(y, weights, coord): +def _estimate_weights(y, weights, coord, coil_dim=0): if weights is None and coord is None: with sp.get_device(y): - weights = (sp.rss(y, axes=(0,)) > 0).astype(y.dtype) + weights = (sp.rss(y, axes=(coil_dim, ), + keepdims=True) > 0).astype(y.dtype) + + if weights is None and coord is not None: + with sp.get_device(y): + weights = (sp.rss(y, axes=(coil_dim, ), + keepdims=True) > 0).astype(y.dtype) return weights @@ -586,3 +589,434 @@ def _output(self): return mps, max_eig else: return mps + + +def _get_regularization(ishape, regu='TIK', lamda=0, + regu_kspace=False, + regu_axes=[-2, -1], + blk_shape=(8, 8), + blk_strides=(8, 8), + normalization=False, + thresh='soft', + deep_model=None, + ro_extend_fold=1): + """ + This function constructs regularization terms. + + Author: + Zhengguo Tan + """ + + idx = [] + for n in range(-len(ishape), -1, 1): + idx.append(slice(0, ishape[n], 1)) + + Nx_ext = ishape[-1] + Nx = Nx_ext // ro_extend_fold + + Slices = [] + for n in range(ro_extend_fold): + idn = idx.copy() + idn.append(slice(n * Nx, (n+1) * Nx, 1)) + + Slices.append(sp.linop.Slice(ishape, tuple(idn))) + + trafos = sp.linop.Vstack(Slices, axis=0) + + + if regu_kspace is True: + trafos = sp.linop.FFT(trafos.oshape, axes=[-2, -1]) * trafos + else: + trafos = sp.linop.Identity(trafos.oshape) * trafos + + + if regu == 'TIK': + + proxg = None + strength = lamda + + elif regu == 'LLR': + + blk_shape = list(blk_shape) + blk_strides = list(blk_strides) + + for ind in range(len(blk_shape)): + if ishape[ind-2] <= blk_shape[ind-2]: + blk_shape[ind-2] = ishape[ind-2] // 3 + blk_strides[ind-2] = ishape[ind-2] // 3 + + proxg = sp.prox.LLRL1Reg(trafos.oshape, lamda, + blk_shape=blk_shape, + blk_strides=blk_strides, + normalization=normalization) + + # the lamda is passed into prox, + # so no need for strength here. + strength = 0 + + elif regu == 'SLR': + + blk_shape = list(blk_shape) + blk_strides = list(blk_strides) + + proxg = sp.prox.SLRMCReg(trafos.oshape, lamda, + blk_shape=blk_shape, + blk_strides=blk_strides, + thresh=thresh, verbose=True) + + strength = 0 + + elif regu == 'TV': + + T = sp.linop.FiniteDifference(trafos.oshape, axes=regu_axes) + trafos = T * trafos + + proxg = sp.prox.L1Reg(trafos.oshape, lamda) + strength = 0 + + elif (regu == 'DAE') and (deep_model is not None): + + proxg = sp.prox.DAEReg(trafos.oshape, lamda, deep_model) + strength = 0 + + g = None + + # TODO: It is difficult to define g here, because + # proxg functions like SLR and LLR contain linear transformations + # inside their prox implementation. + + # if proxg is None: + # g = None + # else: + # def g(input): + # device = sp.get_device(input) + # xp = device.xp + # with device: + # return lamda * xp.sum(xp.abs(input)).item() + + return trafos, proxg, g, strength + + +class HighDimensionalRecon(sp.app.LinearLeastSquares): + r"""High-Dimensional MRI Reconstruction. + + Considers the problem + + .. math:: + \min_x \frac{1}{2} \| P F S M x - y \|_2^2 + + \frac{\lambda}{2} R(x) + + where P is the sampling operator, + F is the Fourier transform operator, + S is the SENSE operator (multiplication with coil sensitivity maps), + M is the modeling operator, which can be + - subspace matrix, + - phase matrix, + - etc. + x is the subspace coefficient maps, and + y is the k-space measurements. + + R(x) is the regularization term. + (1) TIK: Tikhonov L2. + R(x) = \| x \|_2^2 + (2) LLR: Locally Low Rank. + R(x) = \| G(x) \|_1, + where G is the transformation function. + (3) SLR: Structured Low Rank. + (4) TV: total variation. + + Args: + y (array): measured k-space data. + mps (array): coil sensitivity maps. + + Author: + Zhengguo Tan + + References: + * Lustig M, Donoho DL, Pauly JM. + Sparse MRI: The application of compressed sensing for rapid MRI. + Magn Reson Med 2007;58:1182-1195. + + * Block KT, Uecker M, Frahm J. + Undersampled radial MRI with multiple coils. + Iterative image reconstruction using a total variation constraint. + Magn Reson Med 2007;57:1086-1098. + + * Lustig M, Donoho DL, Santos JM, Pauly JM. + Compressed sensing MRI. + IEEE Signal Process Mag 2008;25:72-82. + """ + def __init__(self, y, mps, lamda=0, + weights=None, coord=None, + use_dcf=False, + basis=None, + phase_echo=None, combine_echo=True, + phase_sms=None, + scale=0, regu='TIK', regu_kspace=False, + regu_axes=[-2, -1], x=None, + blk_shape=(8, 8), blk_strides=(8, 8), + normalization=False, + deep_model=None, + thresh='soft', max_iter=50, + ro_extend_fold=1, solver=None, + device=sp.cpu_device, show_pbar=True, + **kwargs): + + # k-space data in accordance with sigpy/mri/dims.py + Ntime, Necho, Ncoil, Nz = y.shape[:-2] + Ny, Nx = mps.shape[-2:] + + assert(1 == Nz) # deal with collapsed y even for SMS + + if phase_sms is not None: + MB = phase_sms.shape[DIM_Z] + else: + MB = 1 + + # start to construct image shape + img_shape = [1] + [MB] + [Ny] + [Nx] + + # %% construct MRI forward model + + #### case 1. subspace modeling + if basis is not None: + + Ncontrast, Ncoef = basis.shape + assert(Ncontrast == Ntime * Necho) + + ishape = [Ncoef] + [1] + img_shape + + sub_ishape = [Ncoef] + [np.prod(ishape[1:])] + sub_oshape = [Ntime] + [Necho] + img_shape + + B1 = sp.linop.Reshape(sub_ishape, ishape) + B2 = sp.linop.MatMul(B1.oshape, basis) + B3 = sp.linop.Reshape(sub_oshape, B2.oshape) + + B = B3 * B2 * B1 + + else: + + if combine_echo is True: + + assert(phase_echo is not None) + ishape = [Ntime] + [1] + img_shape + + else: + + ishape = [Ntime] + [Necho] + img_shape + + B = sp.linop.Identity(ishape) + + + #### case 2. echo phase modeling + if phase_echo is not None: + + self._check_two_shape([Ntime] + [Necho] + img_shape, phase_echo.shape) + + P = sp.linop.Multiply(B.oshape, phase_echo) + + else: + + P = sp.linop.Identity(B.oshape) + + + #### parallel imaging modeling + + # only one set of coil sensitivity maps for all images + assert(y.shape[DIM_COIL] == mps.shape[DIM_COIL]) + + S = sp.linop.Multiply(P.oshape, mps) + + + # FFT + if coord is None: + self._check_two_shape(list(y.shape[DIM_Y:]), mps.shape[DIM_Y:]) + F = sp.linop.FFT(S.oshape, axes=range(-2, 0)) + + else: + F = sp.linop.HDNUFFT(S.oshape, coord, use_dcf=use_dcf) + + + # SMS + if phase_sms is not None: + + self._check_two_shape(list(phase_sms.shape), mps.shape[DIM_Z:]) + + PHI = sp.linop.Multiply(F.oshape, phase_sms) + SUM = sp.linop.Sum(PHI.oshape, axes=(DIM_Z, ), keepdims=True) + + M = SUM * PHI + + else: + + M = sp.linop.Identity(F.oshape) + + + # compute k-space sampling mask + if weights is None: + weights = _estimate_weights(y, weights, coord, coil_dim=DIM_COIL) + + y = sp.to_device(y * weights**0.5, device=device) + + W = sp.linop.Multiply(M.oshape, weights**0.5) + + + #### chain models + A = W * M * F * S * P * B + + + # %% scale y + if scale == 0: + device = sp.get_device(y) + xp = device.xp + with device: + x0 = A.H(y) + x1 = np.linalg.norm(x0, axis=0) + x2 = np.linalg.norm(x1, axis=0) + x1 = xp.sort(x2.flatten()) + x_med = abs(x1[len(x1)//2]) + x_p90 = abs(x1[int(len(x1) * 0.9)]) + x_max = abs(x1[-1]) + if (x_max - x_p90) < 2 * (x_p90 - x_med): + scale = x_p90 + else: + scale = x_max + + # print('scale: ' + str(scale)) + y /= scale + + + # %% initialization + if x is not None: + x = sp.to_device(x, device=device) + + + # %% regularization + trafos, proxg, g, strength = _get_regularization( + A.ishape, regu=regu, lamda=lamda, + regu_kspace=regu_kspace, regu_axes=regu_axes, + blk_shape=blk_shape, blk_strides=blk_strides, + normalization=normalization, + thresh=thresh, deep_model=deep_model, + ro_extend_fold=ro_extend_fold) + + if solver == 'ADMM': + show_pbar = False + + super().__init__(A, y, x=x, lamda=strength, + G=trafos, proxg=proxg, g=g, scale=scale, + max_iter=max_iter, solver=solver, + show_pbar=show_pbar, **kwargs) + + + def _check_two_shape(self, ref_shape, dst_shape): + for i1, i2 in zip(ref_shape, dst_shape): + if (i1 != i2): + raise ValueError('shape mismatch for ref {ref}, got {dst}'.format( + ref=ref_shape, dst=dst_shape)) + + +class ModelDiffRecon(sp.app.NonLinearLeastSquares): + r"""Model-based Diffusion Reconstruction. + + Consider the problem + + .. math:: + \min_x \frac{1}{2} \| P F S B x - y \|_2^2 + + \frac{\lambda}{2} R(x) + + where P is the sampling operator, + F is the Fourier transform operator, + S is the SENSE operator (multiplication with coil sensitivity maps), + B is the non-linear diffusion model, + x is the diffusion model parameters, [b0, DTI or DKI], and + y is the k-space measurements. + + Therefore, The operation B x outputs + non-diffusion-weighted image (b0) and + diffusion-weighted images (DWI). + + Args: + y (array): Observation. + image_shape (list): Shape of Solution. + coil (array): Coil sensitivity maps. + coil_dim (int): Dimension of coil sensitivity in y. + coord (array): Sampling trajectory. + x (array): Solution. + x0 (array): Bias for regularization. + dwi_phase (array): Phase of diffusion weighted images. + weights (array): Sampling pattern. + encode_matrix (array): Diffusion encoding matrix (B). + max_iter (int): Maximum number of iterations. + lamda (float): Regularization strength (\rho in ADMM). + redu (float): Reduction factor along iterations. + gn_iter (int): Gauss-Newton iterations. + inner_iter (int): Inner iterations. + inner_tol (float): Inner tolerance. + G (None or Linop): Regularization linear operator. + proxg (None or Prox): Proximal operator for regularization. + device (Device): Use CPU or GPU. + + Author: + Zhengguo Tan + + References: + Welsh C. L., DiBella E. V. R., Adluru G., Hsu E. W. (2013). + Model-based reconstruction of undersampled diffusion tensor k-space data. + Magn. Reson. Med., 70, 429-440. + + Knoll F., Raya J. G., Halloran R. O., Baete S., Sigmund E., Bammer R., Block K. T., Otazo R., Sodickson D. K. (2015). + A model-based reconstruction for undersampled radial spin-echo DTI with variational penalties on the diffusion tensor. + Magn. Reson. Med., 28, 353-366. + + Dong Z., Dai E., Wang F., Zhang Z., Ma X., Yuan C., Guo H. (2018). + Model-based reconstruction for simultaneous multislice and parallel imaging accelerated multishot DTI. + Med. Phys., 45, 3196-3204. + + """ + def __init__(self, y, image_shape, + coil=None, coil_dim=-3, coord=None, + x=None, x0=None, + dwi_phase=None, weights=None, + encode_matrix=None, + max_iter=6, lamda=1, redu=2, + gn_iter=6, inner_iter=100, inner_tol=0.01, + G=None, proxg=None, + device=sp.cpu_device, + **kwargs): + + y = sp.to_device(y, device=device) + + xp = device.xp + + with device: + weights = _estimate_weights(y, weights, coord, coil_dim=coil_dim) + + A = nlop.Diffusion(image_shape, encode_matrix, coil, + dwi_phase=dwi_phase, weights=weights) + + if x is None: + with device: + x_b0 = xp.ones([1] + list(image_shape[1:]), dtype=y.dtype) * 1E-5 + x_D = xp.zeros([image_shape[0]-1] + list(image_shape[1:]), dtype=y.dtype) + x = xp.concatenate((x_b0, x_D)) + else: + with device: + x = sp.to_device(x, device=device) + + if x0 is None: + with device: + x0 = 0.9 * x + else: + with device: + x0 = sp.to_device(x0, device=device) + + super().__init__(A, y, x=x, x0=x0, + max_iter=max_iter, + lamda=lamda, redu=redu, + gn_iter=gn_iter, + inner_iter=inner_iter, + inner_tol=inner_tol, + G=G, proxg=proxg, + **kwargs) diff --git a/sigpy/mri/cc.py b/sigpy/mri/cc.py new file mode 100644 index 00000000..66200d31 --- /dev/null +++ b/sigpy/mri/cc.py @@ -0,0 +1,69 @@ +"""Coil compression functions. + +Author: Zhengguo Tan +""" +import sigpy as sp + +__all__ = ['scc', 'gcc'] + + +def scc(kdat, P=10, coil_dim=-2, device=sp.cpu_device): + r"""Coil compression based on SVD. + + Args: + kdat (array): raw k-space data. + P (int): number of virtual coils to be kept. [Default: 10]. + coil_dim (int): the coil dimension. [Default: -2]. + device: use CPU or GPU device. + + Returns: + coil compressed k-space data, and + truncated eigen vectors. + + References: + * Buehrer M., Pruessmann K. P., Boesiger P., Kozerke S. (2007). + Array compression for MRI with large coil arrays. + Magn. Reson. Med., 2007. + + * Huang F., Vijayakumar S., Li Y., Hertel S., Duensing G. R. (2008). + A software channel compression technique for faster reconstruction with many channels. + Magn. Reson. Imaging., 26, 133-141. + """ + if P >= kdat.shape[coil_dim]: + print('> return the original data') + return kdat, None + + device = sp.Device(device) + xp = device.xp + + with device: + # move the dimension of coils to 0 + y1 = xp.swapaxes(sp.to_device(kdat, device=device), coil_dim, 0) + y2 = xp.reshape(y1, (y1.shape[0], -1)) + + # covariance matrix: [num_coil, num_coil] + yc = xp.cov(y2) + + eigvals, eigvecs = xp.linalg.eigh(yc) + + # eigvals and eigvecs in descending order + idx = eigvals.argsort()[::-1] + eigvals = eigvals[idx] + eigvecs = eigvecs[:, idx] + + print('> energy kept: ' + + '%3.4f'%(xp.sum(eigvals[:P]) / xp.sum(eigvals))) + + S = eigvecs[:, :P] + + y3 = xp.conj(S.T) @ y2 + + y4 = xp.reshape(y3, [P] + list(y1.shape[1:])) + y5 = xp.swapaxes(y4, coil_dim, 0) + + return sp.to_device(y5, device=sp.get_device(kdat)), S + + +# TODO: geometric coil compression +def gcc(): + None diff --git a/sigpy/mri/dce.py b/sigpy/mri/dce.py new file mode 100644 index 00000000..6e086d59 --- /dev/null +++ b/sigpy/mri/dce.py @@ -0,0 +1,263 @@ +"""Functions for Dynamic Contrast Enhanced (DCE) MRI + +Author: Zhengguo Tan +""" +import numpy as np +import sigpy as sp + +from sigpy import backend + + +def arterial_input_function(sample_time, + A = [0.7668, 0.3309], + T = [0.1710, 0.4079], + sigma = [0.0583, 0.1359], + alpha = 1.0259, + beta = 0.1608, + s = 22.1275, + tau = 0.503, + Hct = 0.4): + """ + Args: + sample_time (array): sampling time array for AIF calculation. [unit: minutes] + + Please refer to the following references for the definition of other parameters. + + References: + * Parker GJM, Roberts C, Macdonald A, Buonaccorsi GA, Cheung S, Buckley DL, Jackson A, Watson Y, Davies K, Jayson GC. + Experimentally-derived functional form for a population-averaged high-temporal-resolution arterial input function for dynamic contrast-enhanced MRI. + Magnetic Resonance in Medicine 56:993-1000 (2006). + + * Tofts PS, Berkowitz B, Schnall MD. + Quantitative analysis of dynamic Gd-DTPA enhancement in breast tumors using a permeability model. + Magnetic Resonance in Medicine 33:564-568 (1995). + + * https://mriquestions.com/uploads/3/4/5/7/34572113/dce-mri_siemens.pdf + """ + + sample_mask = sample_time > 0 + + Cp = np.zeros_like(sample_time) + + # sigmoid function + exp_b = np.exp(-beta * sample_time) + exp_s = np.exp(-s * (sample_time - tau)) + sigmoid_vals = alpha * exp_b / (1 + exp_s) + + # Gaussian functions + for n in range(len(A)): + scale = A[n] / (sigma[n] * (2*np.pi)**0.5) + exp_t = np.exp(-((sample_time - T[n])**2.)/(2 * sigma[n]**2)) + Cp += scale * exp_t + + Cp += sigmoid_vals + + Cp *= sample_mask + + # Cp /= 3 # scaling + + # Cp /= (1 - Hct) + + return Cp + + +def Patlak(ishape, sample_time, device=sp.cpu_device): + """ + Args: + ishape (tuple or list of int): input parameter maps shape [Np, Ny, Nx]. + sample_time (array): sampling time array of the AIF. + + Output: + linop of C_time: contrast agent signal along time. + mult (array): + + References: + * Patlak C, Blasberg RG, Fenstermacher JD. + Graphical Evaluation of Blood-to-Brain Transfer Constants from Multiple-Time Uptake Data. + Journal of Cerebral Blood Flow & Metabolism (1983). + """ + + Np = ishape[0] + Ny, Nx = ishape[-2:] + # Np, Ny, Nx = ishape # parameter shape (K, V).^T + assert(2 == Np) + + sample_time = np.squeeze(sample_time) + assert(1 == sample_time.ndim) + + t0_idx = np.nonzero(sample_time == 0) + dt0 = sample_time[t0_idx] + + t1_idx = np.nonzero(sample_time) + dt1 = np.diff(sample_time[t1_idx], prepend=0) + + dt = np.concatenate((dt0, dt1)) + + Cp = arterial_input_function(sample_time) + K_time = np.cumsum(Cp) * dt[19] + mult = np.array([K_time, Cp]).T + + mult_dev = sp.to_device(mult, device=device) + + R = sp.linop.Reshape([Np, Ny*Nx], ishape) + M = sp.linop.MatMul(R.oshape, mult_dev) + B = sp.linop.Reshape([len(sample_time), 1, 1, 1, Ny, Nx], M.oshape) + + return B * M * R, mult_dev + + +def _array_to_device(input, device=sp.cpu_device): + + if isinstance(input, backend.get_array_module(input).ndarray): + output = backend.to_device(input, device=device) + elif np.isscalar(input): + output = input + + return output + + +class DCE(sp.nlop.Nlop): + """Tracer Kinetic Modeling + + This non-linear operator maps CA (contrast agent concentration) to + acquired MR image. i.e. + + input: DCE parameters [Np, 1, 1, Ny, Nx] + --> + CA [Ntime, 1, 1, 1, Ny, Nx] + --> + output: MR image of the same shape as CA. + + Args: + ishape (tuple of list of int): input shape. + sample_time (array): sampling time array of the AIF. + R1 (float scalar or array): inverse of baseline T1 values [default: 1]. + M0 (float scalar or array): baseline M0 magnetization [default: 5]. + R1CA (float scalar or array): inverse of CA T1 values [default: 4.39]. + FA (float): flip angle in degree [default: 15]. + TR (float): repetition time in second [default 0.006]. + + References: + * Guo Y, Lingala SG, Zhu Y, Lebel RM, Nayak KS. + Direct estimation of tracer-kinetic parameter maps from highly undersampled brain dynamic contrast enhanced MRI. + Magnetic Resonance in Medicine 78:1566-1578 (2017). + """ + + def __init__(self, ishape, + sample_time, + R1 = 1, + M0 = 5, + R1CA = 4.39, + FA = 15, + TR = 0.006, # second + rvc = True, + verbose = False, + device = sp.cpu_device, + repr_str = None): + + Np = ishape[-5] + Nz, Ny, Nx = ishape[-3:] + assert(1 == ishape[-4]) + + + sample_time = np.squeeze(sample_time) + + P, mult = Patlak(ishape, sample_time) + + self.Patlak = P + + oshape = P.oshape + + xp = device.xp + with device: + + self.mult = _array_to_device(mult, device=device) + + self.R1 = _array_to_device(R1, device=device) + self.M0 = _array_to_device(M0, device=device) + self.R1CA = _array_to_device(R1CA, device=device) + + FA_radian = FA * xp.pi / 180 + + self.M0_trans = self.M0 * xp.sin(FA_radian) + + E1 = xp.exp(-TR * self.R1) + self.M_steady = self.M0_trans * (1 - E1) / (1 - E1 * xp.cos(FA_radian)) + + self.FA = FA + self.TR = TR + + self.rvc = rvc + + self.device = device + + self.verbose = verbose + + super().__init__(oshape, ishape, repr_str=repr_str) + + def _forward(self, input): + self.x = _array_to_device(input, device=self.device) + xp = self.device.xp + + with self.device: + CA = self.Patlak(self.x) + x0 = 1. + + FA_radian = self.FA * xp.pi / 180 + + E1CA = xp.exp(-self.TR * (self.R1 + self.R1CA * CA)) + + CA_trans = self.M0_trans * (1 - E1CA) / (1 - E1CA * xp.cos(FA_radian)) + + output = CA_trans + x0 - self.M_steady + + return output + + def _get_Jacobian(self, x): + self.x = x + device = backend.get_device(self.x) + xp = device.xp + + FA_radian = self.FA * xp.pi / 180 + + with device: + CA = self.Patlak(self.x) + + cosFA = xp.cos(FA_radian) + + E1CA = xp.exp(-self.TR * (self.R1 + self.R1CA * CA)) + + dCA = self.R1CA * self.M0_trans *\ + ((self.TR * E1CA * (1 - E1CA * cosFA)) - \ + (1 - E1CA) * self.TR * E1CA * cosFA) \ + / (1 - cosFA * E1CA)**2 + + Jaco = ((self.mult.T) * (dCA.T)).T + + if self.verbose: + print('> dCA shape: ', dCA.shape) + print('> Jaco shape: ', Jaco.shape) + + return Jaco + + def _derivative(self, x, dx): + device = backend.get_device(dx) + xp = device.xp + + with device: + self.Jacobian = self._get_Jacobian(x) + return xp.sum(self.Jacobian * dx, axis=1, keepdims=True) + + def _adjoint(self, x, dy): + device = backend.get_device(dy) + xp = device.xp + + with device: + self.Jacobian = self._get_Jacobian(x) + JH = xp.conjugate(self.Jacobian) + dx = xp.sum(JH * dy, axis=0) + + if self.rvc: + dx = dx.real + 0. * 1j + + return dx diff --git a/sigpy/mri/diff.py b/sigpy/mri/diff.py new file mode 100644 index 00000000..cced418e --- /dev/null +++ b/sigpy/mri/diff.py @@ -0,0 +1,33 @@ +""" +Methods for Diffusion MRI. + +Author: Zhengguo Tan +""" + +# %% +import numpy as np +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D + +# %% Convert Cartesian coordinates to polar coordinates +def cartes2polar(x, y, z): + r = np.sqrt(x**2 + y**2 + z**2) + theta = np.arccos(z / r) + varphi = np.zeros_like(r) + + ind = np.where(x > 0) + varphi[ind] = np.arctan(y[ind] / x[ind]) + + ind = np.where((x < 0) & (y >= 0)) + varphi[ind] = np.arctan(y[ind] / x[ind]) + np.pi + + ind = np.where((x < 0) & (y < 0)) + varphi[ind] = np.arctan(y[ind] / x[ind]) - np.pi + + ind = np.where((x == 0) & (y > 0)) + varphi[ind] = np.pi / 2 + + ind = np.where((x == 0) & (y < 0)) + varphi[ind] = - np.pi / 2 + + return r, theta, varphi \ No newline at end of file diff --git a/sigpy/mri/dims.py b/sigpy/mri/dims.py new file mode 100644 index 00000000..d8f53bee --- /dev/null +++ b/sigpy/mri/dims.py @@ -0,0 +1,9 @@ +"""Generalized data dimensions for multi-purpose MRI image reconstruction. +""" + +DIM_X = -1 +DIM_Y = -2 +DIM_Z = -3 +DIM_COIL = -4 +DIM_ECHO = -5 # For diffusion MRI, it stores the shots per DWI +DIM_TIME = -6 # For diffusion MRI, it stores the diffusion encodings diff --git a/sigpy/mri/dvs.py b/sigpy/mri/dvs.py new file mode 100644 index 00000000..36dc0e5f --- /dev/null +++ b/sigpy/mri/dvs.py @@ -0,0 +1,131 @@ +"""Diffusion Vector Sets (dvs) +for the use of customized diffusion directions on the scanner + +Author: Zhengguo Tan +""" + +import numpy as np +from pathlib import Path + +list_CoordinateSystem = ['xyz', 'prs'] +list_Normalisation = ['unity', 'maximum', 'none'] + +def _read_directions(line_str): + + line_sp1 = line_str.split('=')[1] + line_sp2 = line_sp1.split(']')[0] + + return int(line_sp2) + +def _read_from_list(line_str, list_str): + + line_sp1 = line_str.split('=')[1] + + for l in list_str: + if l in line_sp1: + return l + +def _read_Vector(line_str): + + line_sp1 = line_str.split('[')[1] + line_sp2 = line_sp1.split(']')[0] + + row_ind = int(line_sp2) + + line_sp1 = line_str.split('(')[1] + line_sp2 = line_sp1.split(')')[0] + line_vec = line_sp2.split(',') + + row_arr = [0 for n in range(len(line_vec))] + + for n in range(len(line_vec)): + row_arr[n] = float(line_vec[n]) + + return row_ind, row_arr + +# %% read .dvs file +def read(dvs_file): + FILE = Path(dvs_file) + + if not FILE.exists(): + print('ERROR: can not find ' + dvs_file) + return None + + if not FILE.is_file(): + print('ERROR: ' + dvs_file + ' is not readable') + return None + + file_content = FILE.read_text() + file_lines = file_content.splitlines() + + + list_Params = [] + list_Vector = [] + + for line in file_lines: + + if line.startswith('Vector'): + list_Vector.append(line) + else: + list_Params.append(line) + + + for line in list_Params: + + if 'directions=' in line: + + directions = _read_directions(line) + + elif 'CoordinateSystem' in line: + + CoordinateSystem = _read_from_list(line, list_CoordinateSystem) + + elif 'Normalisation' in line: + + Normalisation = _read_from_list(line, list_Normalisation) + + + Vector = np.zeros((directions, 3)) + + for line in list_Vector: + + row_ind, row_arr = _read_Vector(line) + + Vector[row_ind, :] = row_arr + + + return directions, CoordinateSystem, Normalisation, Vector + + +# %% write .dvs file +def write(file, Vector, CoordinateSystem='xyz', Normalisation='none'): + + num_Vector, num_dirs = Vector.shape + + with open(file, 'w') as fp: + + directions_str = '[directions=' + str(num_Vector) + ']' + fp.write(directions_str + '\n') + + CoordinateSystem_str = 'CoordinateSystem = ' + CoordinateSystem + fp.write(CoordinateSystem_str + '\n') + + Normalisation_str = 'Normalisation = ' + Normalisation + fp.write(Normalisation_str + '\n') + + for n in range(num_Vector): + + Vector_str = 'Vector[' + str("%3d"%n) + '] = ( ' + + for d in range(num_dirs): + + Vector_str = Vector_str + str("%9.6f" % Vector[n, d]) + + if d+1 == num_dirs: + Vector_str = Vector_str + ' )' + else: + Vector_str = Vector_str + ', ' + + fp.write(Vector_str + '\n') + + print('finish writing to file ' + file) \ No newline at end of file diff --git a/sigpy/mri/epi.py b/sigpy/mri/epi.py new file mode 100644 index 00000000..ca266ce6 --- /dev/null +++ b/sigpy/mri/epi.py @@ -0,0 +1,348 @@ +# -*- coding: utf-8 -*- +"""Methods for Echo-Planar Imaging (EPI) acquisition +with an focus on diffusion tensor/kurtosis imaging. + +Author: + Zhengguo Tan +""" +import numpy as np +from sigpy import fourier + +MIN_POSITIVE_SIGNAL = 0.0001 + + +def phase_corr(kdat, pcor, topup_dim=-11): + """perform phase correction. + + Args: + kdat (array): k-space data to be corrected + pcor (array): phase-correction reference data + topup_dim (int): dimension of the top-up and top-down measurements + + Output: + phase-corrected k-space data + + Reference: + Ehses P. https://github.com/pehses/twixtools + """ + col_dim = -1 + + ncol = pcor.shape[col_dim] + npcl = pcor.shape[topup_dim] + + # if three lines are present, + # average the 1st and the 3rd line. + if npcl == 3: + pcors = np.swapaxes(pcor, 0, topup_dim) + pcor1 = pcors[[0, 2], ...] + pcor_odd = np.mean(pcor1, axis=0, keepdims=True) + pcor_eve = pcors[[1], ...] + pcor = np.concatenate((pcor_odd, pcor_eve)) + pcor = np.swapaxes(pcor, 0, topup_dim) + + oshape = list(kdat.shape) + oshape[topup_dim] = 1 + + output = np.zeros(oshape) + + pcor_img = fourier.ifft(pcor, axes=[col_dim]) + kdat_img = fourier.ifft(kdat, axes=[col_dim]) + + slope = np.angle((np.conj(pcor_img[..., 1:]) * pcor_img[..., :-1]) + .sum(col_dim, keepdims=True).sum(-2, keepdims=True)) + x = np.arange(ncol) - ncol//2 + + pcor_fac = np.exp(1j * slope * x) + + kdat_img *= pcor_fac + kdat_img = kdat_img.sum(topup_dim, keepdims=True) + output = fourier.fft(kdat_img, axes=[-1]) + + return output + + +def get_B(b, g): + """Compute B matrix from b values and g vectors + + Args: + b (1D array): b values + g (2D array): g vectors + + Output: + B (array): [gx**2, gy**2, gz**2, + 2*gx*gy, 2*gx*gz, 2*gy*gz] of every pixel + """ + num_g, num_axis = g.shape + + assert num_axis == 3 + assert num_g == len(b) + + gx = g[:, 0] + gy = g[:, 1] + gz = g[:, 2] + + return - b * np.array([gx**2, 2*gx*gy, gy**2, + 2*gx*gz, 2*gy*gz, gz**2]).transpose() + + +def get_B2(b, g): + """For Diffusion Kurtosis: + Compute B2 matrix from b values and g vectors + + Args: + b (1D array): b values + g (2D array): g vectors + + Output: + B (array) + """ + num_g, num_axis = g.shape + + assert num_axis == 3 + assert num_g == len(b) + + gx = g[:, 0] + gy = g[:, 1] + gz = g[:, 2] + + BT = get_B(b, g) + + BK = b * b * np.array([ + gx**4 / 6, + gy**4 / 6, + gz**4 / 6, + 4 * gx**3 * gy / 6, + 4 * gx**3 * gz / 6, + 4 * gy**3 * gx / 6, + 4 * gy**3 * gz / 6, + 4 * gz**3 * gx / 6, + 4 * gz**3 * gy / 6, + gx**2 * gy**2, + gx**2 * gz**2, + gy**2 * gz**2, + 2 * gx**2 * gy * gz, + 2 * gy**2 * gx * gz, + 2 * gz**2 * gx * gy]).transpose() + + return np.concatenate((BT, BK), axis=1) + + +def get_D(B, sig, fit_method='wls', fit_only_tensor=False, + min_signal=0, fit_kt=False): + """Compute D matrix (diffusion tensor) + + Args: + B (array): see above. + sig (array): b0 image and diffusion-weighted images. + fit_method (string): [default: 'wls'] + - 'wls' weighted least square + - 'ols' ordinary least square + fit_only_tensor (boolean): excluding b0 [default: False] + min_signal (float): minimal signal intensity in DWI + [Default: MIN_POSITIVE_SIGNAL] + better set to 0. + fit_kt (boolean): fit kurtosis tensor directly [default: False] + + Output: + D (array): [Dxx, Dxy, Dyy, Dxz, Dyz, Dzz] of every pixel. + Please refer to get_B() and get_B2() for the actual order + of the D array. + + References: + Chung S. W., Lu Y., Henry R. G. (2006). + Comparison of bootstrap approaches for + estimation of uncertainties of DTI parameters. + NeuroImage 33, 531-541. + + DiPy. https://github.com/dipy/dipy + """ + sig = np.abs(sig) + sig = np.maximum(sig, min_signal) + S = np.log(sig, out=np.zeros_like(sig), where=(sig != 0)) + + ndiff = S.shape[0] + image_shape = S.shape[1:] + + if fit_only_tensor is True: + y = S[0, ...] - S + else: + y = S + dummy = np.ones((B.shape[0], 1)) + B = np.concatenate((B, dummy), axis=1) + + nparam = B.shape[1] + yr = y.reshape(ndiff, -1) + + # print('> OLS Fitting') + xr = np.dot(np.linalg.pinv(B), yr) + D_fit = xr.reshape([nparam] + list(image_shape)) + + if fit_method == 'wls': + # print('> WLS Fitting') + + if fit_kt is True: + eigvals, eigvecs = get_eig(D_fit, B) + MD2 = get_MD(eigvals)**2 + scale = np.tile(MD2[None, ...], [nparam] + [1] * len(image_shape)) + scale = np.reshape(scale, (nparam, -1)) + scale[:6, ...] = 1 + scale[-1, ...] = 1 + else: + scale = np.ones_like(xr) + + scale = np.expand_dims(scale.T, axis=1) + + w = np.exp(np.dot(B, xr)).T # weight + + lhs = np.linalg.pinv(B * w[..., None] * scale, rcond=1e-15) + lhs = np.swapaxes(lhs, 0, 1) + + rhs = (w.T * yr).T + + xr = np.sum(lhs * rhs, axis=-1) + + return xr.reshape([nparam] + list(image_shape)) + + +_lt_indices = np.array([[0, 1, 3], + [1, 2, 4], + [3, 4, 5]]) + + +def DT_vec2mat(Dvec): + """Convert the 6 elements of diffusion tensor (DT) + to a 3x3 symmetric matrix + """ + assert 6 == Dvec.shape[0] + + return Dvec[_lt_indices, ...] + + +def get_eig(D, B=None): + """Compute eigenvalues and eigenvectors of the D matrix + + Args: + D (array): output from get_D(B, sig) + + Output: + eigvals: eigenvalues + eigvecs: eigenvectors + """ + image_shape = D.shape[1:] + image_size = np.prod(image_shape) + + Dmat = DT_vec2mat(D[:6, ...]) + temp = np.rollaxis(Dmat, 0, len(Dmat.shape)) + Dmat = np.rollaxis(temp, 0, len(Dmat.shape)) + eigvals, eigvecs = np.linalg.eigh(Dmat) + + # flatten eigvals and eigenvecs + eigvals = eigvals.reshape(-1, 3) + eigvecs = eigvecs.reshape(-1, 3, 3) + + order = eigvals.argsort()[:, ::-1] + + xi = np.ogrid[:image_size, :3][0] + eigvals = eigvals[xi, order] + + xi, yi = np.ogrid[:image_size, :3, :3][:2] + eigvecs = eigvecs[xi, yi, order[:, None, :]] + + eigvals = eigvals.reshape(image_shape + (3, )) + eigvecs = eigvecs.reshape(image_shape + (3, 3)) + + eigvals = np.rollaxis(eigvals, -1, 0) + + eigvecs = np.rollaxis(eigvecs, -1, 0) + eigvecs = np.rollaxis(eigvecs, -1, 0) + + if B is not None: + min_diffusivity = 1e-6 / -B.min() + eigvals = eigvals.clip(min=min_diffusivity) + + return eigvals, eigvecs + + +def get_FA(eigvals): + """Compute Fractional Anisotropy (FA) map + + Args: + eigvals (array): output from get_eig(D) + + Output: + FA (array): FA map + """ + l1 = eigvals[0, ...] + l2 = eigvals[1, ...] + l3 = eigvals[2, ...] + + nomi = 0.5 * ((l1-l2)**2 + (l2-l3)**2 + (l3-l1)**2) + deno = l1**2 + l2**2 + l3**2 + + FA = np.sqrt(np.divide(nomi, deno, + out=np.zeros_like(nomi), + where=deno != 0)) + + return FA + + +def get_cFA(FA, eigvecs): + """Compute color-coded Fractional Anisotropy (cFA) map + + Args: + FA (array): FA map + eigvecs (array): eigen vectors + + Output: + cFA (array): cFA map + """ + return np.abs(eigvecs[:, 0, ...]) * FA + + +def get_MD(eigvals): + """Compute Mean Diffusivity (MD) map + + Args: + eigvals (array): output from get_eig(D) + + Output: + MD (array): MD map + """ + assert 3 == eigvals.shape[0] + + return np.mean(eigvals, axis=0) + + +def get_KT(D, B=None): + """Compute Kurtosis Tensor (KT) map + + Args: + D (array): output from get_D(B, sig) + + Output: + KT (array): KT map + """ + assert 21 <= D.shape[0] + DT = D[:6, ...] + DK = D[6:21, ...] + + eigvals, eigvecs = get_eig(DT, B=B) + MD = get_MD(eigvals) + + return DK / (MD**2) + + +def get_ADC(D): + """Compute the apparent diffusion coefficient (ADC) map + + Args: + D (array): diffusion tensor + + Output: + ADC (array): ADC map + """ + Dxx = D[0, ...] + Dyy = D[2, ...] + Dzz = D[5, ...] + + return (Dxx + Dyy + Dzz) / 3 \ No newline at end of file diff --git a/sigpy/mri/gmap.py b/sigpy/mri/gmap.py new file mode 100644 index 00000000..75b3fb01 --- /dev/null +++ b/sigpy/mri/gmap.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +"""Methods for computing the geometry factor (g-map) + +Author: + Zhengguo Tan +""" +import copy +import numpy as np +from sigpy import backend, util +from sigpy.mri import app + + +def pseudo_replica(app, y, mps, + normalize=True, replicas=10): + r""" + Performs multiple reconstruction with the supplied + mri reconstruction app and measured k-space data y. + + Args: + app: an app for image reconstruction (e.g. SenseRecon). + y (array): measured k-space data. + mps (array): coil sensitivity maps. + normalize (Boolean): normalize y or not [default:True]. + replicas (integer): numer of replicas to be ran for the calculation of the g map. + + Reference: + http://hansenms.github.io/sunrise/sunrise2013/ + """ + device = backend.get_device(y) + xp = device.xp + + if normalize is True: + with device: + s = xp.linalg.norm(y) + y = 1E6 * y / s + + with device: + yvec = y.flatten() + ind = yvec.nonzero() + y1 = yvec[ind] + rshape = ind[0].shape + + res = [] + for r in range(replicas): + + n = util.randn(rshape, scale=1, dtype=y.dtype, device=device) + + with device: + yn = y1 + 0.1 * n + yv = xp.zeros_like(yvec) + yv[ind] = yn + ym = yv.reshape(y.shape) + + # cfl.writecfl('test_y_' + str(r), ym) + + # pass noisey y to app + curr_app = app + curr_app.y = xp.zeros_like(ym) + imgn = curr_app.run() + + res.append(imgn) + + res = backend.to_device(xp.array(res)) + + sca = np.max(np.abs(res)) + g = np.std(res + sca, axis=0) + + # g = g * util.rss(mps) + + return g, res \ No newline at end of file diff --git a/sigpy/mri/grappa.py b/sigpy/mri/grappa.py new file mode 100644 index 00000000..57e74bc2 --- /dev/null +++ b/sigpy/mri/grappa.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- +"""Functions for GRAPPA. + +This implementation makes use of the function "sp.array_to_blocks" to slide through calibration as well as undersampled k-space data to perform GRAPPA kernel fitting and reconstruction. + +Alternatively, one can also use other functions, e.g. "view_as_windows" in skimage.util. + +Reference: + https://users.fmrib.ox.ac.uk/~mchiew/Tools.html + https://github.com/mckib2/pygrappa + +Author: + Zhengguo Tan +""" +import numpy as np +import sigpy as sp + +__all__ = ['Grappa', 'SliceGrappa'] + + +def _set_kspace_boundary(kspace, pad_shape): + + py, px = pad_shape + kspace[..., :px] = kspace[..., -px*2:-px] + kspace[..., -px:] = kspace[..., px:2*px] + kspace[..., :py, :] = kspace[..., -2*py:-py, :] + kspace[..., -py:, :] = kspace[..., py:2*py, :] + + return kspace + + +def Grappa(kspace, calib, R=[2, 1], lamda=1E-4, + kernel_shape=[4, 5], + kernel_stride=[1, 1]): + """ + Args: + kspace (array): undersampled 2D k-space data. + calib (array): fully-sampled calibration data. + R (scalar or length-2 list): acceleration factor. + lamda (float): weight fitting regularization. + kernel_shape (length-2 list): [ker_y (even), ker_x (odd)]. + kernel_stride (length-2 list): kernel stride. + + Reference: + Griswold M. A., Jakob P. M., Heidemann R. M., Nittka M., Jellus V., Wang J., Kiefer B., Haase A. (2002). + Generalized autocalibrating partially parallel acquisitions (GRAPPA). + Magn. Reson. Med., 47, 1202-1210. + """ + acc_y, acc_x = R[:] + assert(acc_x == 1) # 1D Grappa with only ky undersampling + + NC, NY, NX = kspace.shape + _C, _Y, _X = calib.shape + + assert(NC == _C) + + ker_y, ker_x = kernel_shape[:] + assert((ker_y % 2 == 0) and (ker_x % 2 == 1)) + + ker_acc_y, ker_acc_x = acc_y * ker_y, acc_x * ker_x + pad_y, pad_x = ker_acc_y // 2 - 1, ker_acc_x // 2 + + # %% train Grappa kernel weights + + None + +def SliceGrappa(kspace, calib, R=2, lamda=1E-4, + kernel_shape=[5, 5], + kernel_stride=[1, 1]): + """ + Args: + kspace (array): undersampled collapsed SMS k-space data. + calib (array): fully-sampled multi-slice calibration data. + lamda (float): weight fitting regularization. + kernel_shape (length-2 list): [ker_y, ker_x] - both must be odd. + kernel_stride (length-2 list): kernel stride. + + Reference: + Setsompop K., Gagoski B. A., Polimeni J. R., Witzel T., Wedeen V. J., Wald L. L. (2012). + Blipped-controlled aliasing in parallel imaging for simultaneous multislice echo planar imaging with reduced g-factor penalty. + Magn. Reson. Med., 67, 1210-1224. + """ + NC, NY, NX = kspace.shape + NS, _NC, _NY, _NX = calib.shape + + assert(NC == _NC) + + ker_y, ker_x = kernel_shape[:] # kernel y and x + pad_y, pad_x = ker_y//2, ker_x//2 # pad y and x + + kernel_len = np.prod(kernel_shape) + weight_len = kernel_len * NC + + # pad kspace and calib + kshape = kspace.shape + kpad_shape = list(kshape[:-2]) + \ + [NY + 2*pad_y] + [NX + 2*pad_x] + + kspace_pad = sp.util.resize(kspace, kpad_shape) + kspace_pad = _set_kspace_boundary(kspace_pad, [pad_y, pad_x]) + + cshape = calib.shape + cpad_shape = list(cshape[:-2]) + \ + [_NY - 2*pad_y] + [_NX - 2*pad_x] + calib_pad = sp.util.resize(calib, cpad_shape) + + + # %% Train kernels slice by slice + + # construct source blocks from calib data + A2B = sp.linop.ArrayToBlocks(cshape, kernel_shape, kernel_stride) + src = A2B(calib) + # src = np.sum(src, axis=0) # sum over the slice axis + src = np.transpose(src, (0, 2, 3, 1, 4, 5)) + + assert(np.prod(src.shape[-3:]) == weight_len) + + src = np.reshape(src, (NS, -1, weight_len)) + src = np.sum(src, axis=0) + + # construct target blocks + A2B = sp.linop.ArrayToBlocks(cpad_shape, [1,1], [1,1]) + trg = A2B(calib_pad) + trg = np.reshape(trg, (NS, NC, np.prod(trg.shape[-4:]))) + + W = np.zeros_like(calib, shape=(NS, NC, weight_len)) + + for s in range(NS): + x1 = src.T + y1 = trg[s, ...] + + SHS = x1 @ x1.T.conj() + lamda * np.linalg.norm(x1) * np.eye(weight_len) + x2 = x1.T.conj() @ np.linalg.pinv(SHS) + + W[s, ...] = y1 @ x2 + + # %% apply trained weights to un-collapse kspace data + + # construct source blocks from kspace data + A2B = sp.linop.ArrayToBlocks(kpad_shape, kernel_shape, kernel_stride) + src = A2B(kspace_pad) + src = np.transpose(src, (1, 2, 0, 3 ,4)) + src = np.reshape(src, (-1, weight_len)).T + + res = np.zeros_like(kspace, shape=(NS, NC, NY, NX)) + for s in range(NS): + res[s, ...] = (W[s, ...] @ src).reshape((-1, NY, NX)) + + return res diff --git a/sigpy/mri/linop.py b/sigpy/mri/linop.py index 90f6988a..e3d22759 100644 --- a/sigpy/mri/linop.py +++ b/sigpy/mri/linop.py @@ -6,6 +6,7 @@ discrete Fourier transform. """ +import numpy as np import sigpy as sp @@ -18,7 +19,8 @@ def Sense( coil_batch_size=None, comm=None, transp_nufft=False, -): + basis=None, + phase=None): """Sense linear operator. Args: @@ -40,14 +42,17 @@ def Sense( """ # Get image shape and dimension. num_coils = len(mps) - if ishape is None: - ishape = mps.shape[1:] - img_ndim = mps.ndim - 1 + img_shape = list(mps.shape[1:]) + img_ndim = mps.ndim - 1 + + if basis is None: + if ishape is None: + ishape = img_shape else: - img_ndim = len(ishape) + if ishape is None: + ishape = [basis.shape[-1]] + [1] + list(mps.shape[-2:]) # Serialize linop if coil_batch_size is smaller than num_coils. - num_coils = len(mps) if coil_batch_size is None: coil_batch_size = num_coils @@ -72,23 +77,40 @@ def Sense( return A + # linear subspace + if basis is None: + B = sp.linop.Identity(ishape) + else: + sub_ishape = (basis.shape[1], np.prod(ishape[1:])) + sub_oshape = [basis.shape[0]] + ishape[1:] + B1 = sp.linop.Reshape(sub_ishape, ishape) + B2 = sp.linop.MatMul(sub_ishape, basis) + B3 = sp.linop.Reshape(sub_oshape, B2.oshape) + B = B3 * B2 * B1 + + # model image phase evolution + if phase is None: + PHI = sp.linop.Identity(B.oshape) + else: + PHI = sp.linop.Multiply(B.oshape, phase) + # Create Sense linear operator - S = sp.linop.Multiply(ishape, mps) + S = sp.linop.Multiply(PHI.oshape, mps) if tseg is None: if coord is None: F = sp.linop.FFT(S.oshape, axes=range(-img_ndim, 0)) else: if transp_nufft is False: - F = sp.linop.NUFFT(S.oshape, coord) + F = sp.linop.HDNUFFT(S.oshape, coord) else: - F = sp.linop.NUFFT(S.oshape, -coord).H + F = sp.linop.HDNUFFT(S.oshape, -coord).H - A = F * S + A = F * S * PHI * B # If B0 provided, perform time-segmented off-resonance compensation else: if transp_nufft is False: - F = sp.linop.NUFFT(S.oshape, coord) + F = sp.linop.HDNUFFT(S.oshape, coord) else: F = sp.linop.NUFFT(S.oshape, -coord).H time = len(coord) * tseg["dt"] diff --git a/sigpy/mri/muse.py b/sigpy/mri/muse.py new file mode 100644 index 00000000..a1808bcb --- /dev/null +++ b/sigpy/mri/muse.py @@ -0,0 +1,305 @@ +""" +MUSE Reconstruction. + +Author: Zhengguo Tan +""" + +import numpy as np +import sigpy as sp + +from sigpy import backend +from sigpy.mri import app, sms +from sigpy.mri.dims import * + + +# %% +def _denoising(input, full_img_shape=None, use_iter=True, max_iter=5): + """ + Args: + input: acs images + full_img_shape: shape of full-FOV images + """ + + # # Hanning + if full_img_shape is None: + full_img_shape = input.shape[-2:] + + device = backend.get_device(input) + xp = device.xp + + with device: + + H = sp.hanning(input.shape[-2:], dtype=complex, symm=True, + device=device) + + H_full = sp.resize(H, full_img_shape) + + k_full = sp.resize(sp.fft(input, axes=[-2, -1]), + oshape=list(input.shape[:-2]) + + list(full_img_shape)) + + + if use_iter: + for m in range(max_iter): + k_full = H_full * k_full + else: + k_full = H_full**max_iter * k_full + + + img = sp.ifft(k_full, axes=[-2, -1]) + + idx = abs(img) > 0 + phs = xp.zeros_like(img) + phs[idx] = img[idx] / abs(img[idx]) + + return img, phs + + +def sms_sense_linop(kdat, coils, yshift, phase_echo=None, + real_value_constraint=False): + + device = backend.get_device(kdat) + assert (device == backend.get_device(coils)) + + Ncoil, Nz, Ny, Nx = coils.shape + + assert (Nz == len(yshift)) + + phase_sms = sms.get_sms_phase_shift([Nz, Ny, Nx], Nz, yshift=yshift) + + img_shape = [1, Nz, Ny, Nx] + + if real_value_constraint is True: + RVC = sp.linop.RealValueConstraint(img_shape) + else: + RVC = sp.linop.Identity(img_shape) + + if phase_echo is not None: + + P = sp.linop.Multiply(img_shape, + sp.to_device(phase_echo, device=device)) + + else: + + P = sp.linop.Identity(img_shape) + + # coils + S = sp.linop.Multiply(P.oshape, coils) + + # FFT + F = sp.linop.FFT(S.oshape, axes=range(-2, 0)) + + # SMS + PHI = sp.linop.Multiply(F.oshape, sp.to_device(phase_sms, device=device)) + SUM = sp.linop.Sum(PHI.oshape, axes=(DIM_Z, ), keepdims=True) + + M = SUM * PHI + + weights = app._estimate_weights(kdat, None, None, coil_dim=DIM_COIL) + + W = sp.linop.Multiply(M.oshape, weights**0.5) + + return W * M * F * S * P * RVC + + +def sms_sense_solve(A, y, lamda=0.01, tol=0, max_iter=30, verbose=False): + + device = backend.get_device((y)) + xp = device.xp + + AHA = lambda x: A.N(x) + lamda * x + AHy = A.H(y) + + img = xp.zeros(A.ishape, dtype=y.dtype) + alg_method = sp.alg.ConjugateGradient(AHA, AHy, img, + tol=tol, + max_iter=max_iter, verbose=verbose) + + while (not alg_method.done()): + alg_method.update() + + return img + + +# %% +def MuseRecon(y, coils, MB=1, acs_shape=[64, 64], + lamda=0.001, max_iter=80, tol=0, + use_readout_extend_fov=False, yshift=None, + real_value_constraint=False, + device=sp.cpu_device, verbose=False): + """ + MUSE is a novel method to reconstruct one diffusion-weighted image (DWI) + from multi-shot EPI acquisition. It consists of the following steps: + 1. shot-by-shot SENSE recon; + 2. phase estimation from every shot image; + 3. incorporate phase into phase-informed SENSE recon to obtain one DWI. + + Args: + y (array): zero-filled k-space data with shape: + [Nshot, Ncoil, Nz_collap, Ny, Nx], where + - Nshot: # of shots per DWI, + - Ncoil: # of coils, + - Nz_collap: # of collapsed slices, + - Ny: # of phase-encoding lines, + - Nx: # of readout lines. + + coils (array): coil sensitivity maps with shape: + [Ncoil, Nz, Ny, Nx], where + - Nz: # of un-collapsed slices. + + MB (int): multi-band factor + MB = Nz / Nz_collap. + + acs_shape (tuple of ints): shape of the auto-calibration signal (ACS), + which is used for the shot-by-shot SENSE recon. + + References: + * Liu C, Moseley ME, Bammer R. + Simultaneous phase correction and SENSE reconstruction for navigated multi-shot DWI with non-Cartesian k-space sampling. + Magn Reson Med 2005;54:1412-1422. + + * Chen NK, Guidon A, Chang HC, Song AW. + A robust multi-shot strategy for high-resolution diffusion weighted MRI enabled by multiplexed sensitivity-encoding (MUSE). + NeuroImage 2013;72:41-47. + """ + Ndiff, Nshot, Ncoil, Nz_collap, Ny, Nx = y.shape + assert(Nshot > 1) # MUSE is a multi-shot technique + + _Ncoil, Nz, _Ny, _Nx = coils.shape + + assert ((Ncoil == _Ncoil) and (Ny == _Ny) and (Nx == _Nx)) + assert ((Nz_collap == Nz / MB)) + + phi = sms.get_sms_phase_shift([MB, Ny, Nx], MB, yshift=yshift) + + if acs_shape is None: + + ksp_acs = y.copy() + mps_acs = coils.copy() + + else: + + ksp_acs = sp.resize(y, oshape=list(y.shape[:-2]) + list(acs_shape)) + + import torchvision.transforms as T + + coils_tensor = sp.to_pytorch(coils) + TR = T.Resize(acs_shape, antialias=True) + mps_acs_r = TR(coils_tensor[..., 0]).cpu().detach().numpy() + mps_acs_i = TR(coils_tensor[..., 1]).cpu().detach().numpy() + mps_acs = mps_acs_r + 1j * mps_acs_i + + print('**** MUSE - ksp_acs shape ', ksp_acs.shape) + print('**** MUSE - mps_acs shape ', mps_acs.shape) + + R_muse = [] + R_shot = [] + for z in range(Nz_collap): # loop over collapsed k-space + + slice_idx = sms.get_uncollap_slice_idx(Nz, MB, z) + mps_acs_slice = mps_acs[:, slice_idx, ...] + + for d in range(Ndiff): + + print('>> muse on slice ' + str(z).zfill(2) + ' diff ' + str(d).zfill(3)) + + if use_readout_extend_fov: + + # 1. perform shot-by-shot ACS SENSE recon to estimate phase + img_ext_shots = [] + for s in range(Nshot): # loop over every shot + + ksp = ksp_acs[d, s, :, z, ...] + + ksp_ext, mps_ext, _ = sms.readout_extended_fov(ksp, mps_acs_slice, MB) + + img_ext = app.SenseRecon(ksp_ext, mps_ext, 5E-5, + max_iter=90, tol=0, + device=device).run() + + img_ext_shots.append(backend.to_device(img_ext)) + + img_ext_shots = np.array(img_ext_shots) + R_shot.append(img_ext_shots) + + # 2. phase estimation from shot images + img_ext_shots_den, phs_ext_shots = _denoising(img_ext_shots, full_img_shape=[Ny, Nx * MB]) + + img_ini = abs(np.mean(img_ext_shots_den * np.conj(phs_ext_shots), axis=0)).astype(phs_ext_shots.dtype) + + # 3. perform phase-informed SENSE recon + # to estimate shot-combined DWI + ksp = y[d, :, :, z, ...] + mps = coils[:, slice_idx, ...] + ksp_ext, mps_ext, _ = sms.readout_extended_fov(ksp, mps, MB) + + phs_ext_shots = np.expand_dims(phs_ext_shots, axis=1) + phs_ext_shots_mps = phs_ext_shots * mps_ext + phs_ext_shots_mps = phs_ext_shots_mps.reshape((-1, Ny, Nx * MB)) + + # -- calculate weights for k-space + arr1 = np.ones((Ncoil, Ny, Nx * MB)) + weights = (sp.rss(ksp_ext, axes=(1, ), keepdims=True) > 0.).astype(ksp_ext.dtype) + weights = arr1 * weights + + ksp_ext = ksp_ext.reshape((-1, Ny, Nx * MB)) + weights = weights.reshape((-1, Ny, Nx * MB)) + + img_ext = app.SenseRecon(ksp_ext, phs_ext_shots_mps, + lamda, max_iter=max_iter, tol=tol, + weights=weights, + x=img_ini, + device=device).run() + + img = sms.readout_unextend_fov(backend.to_device(img_ext), MB) + # img = sp.ifft(np.conj(phi) * sp.fft(img, axes=[-2, -1]), axes=[-2, -1]) + R_muse.append(img) + + else: + + xp = device.xp + ksp_acs = sp.to_device(ksp_acs, device=device) + mps_acs_slice = sp.to_device(mps_acs_slice, device=device) + + y = sp.to_device(y, device=device) + coils = sp.to_device(coils, device=device) + + # 1. perform shot-by-shot ACS SENSE recon to estimate phase + img_acs_shots = [] + for s in range(Nshot): + + ksp = ksp_acs[d, s, :, z, :, :] + ksp = ksp[..., None, :, :] + + A = sms_sense_linop(ksp, mps_acs_slice, yshift) + + img = sms_sense_solve(A, ksp, lamda=5E-5, tol=0, + max_iter=max_iter, verbose=verbose) + + img_acs_shots.append(img) + + img_acs_shots = xp.array(img_acs_shots) + R_shot.append(sp.to_device(img_acs_shots)) + + # 2. phase estimation from shot images + _, phs_shots = _denoising(img_acs_shots, + full_img_shape=[Ny, Nx]) + + # 3. perform phase-informed SENSE recon + # to estimate shot-combined DWI + ksp = y[d, :, :, z, ...] + ksp = ksp[..., None, :, :] + mps = coils[:, slice_idx, ...] + + A = sms_sense_linop(ksp, mps, yshift, phase_echo=phs_shots, + real_value_constraint=real_value_constraint) + + img = sms_sense_solve(A, ksp, lamda=lamda, tol=tol, + max_iter=max_iter, verbose=verbose) + + R_muse.append(sp.to_device(img)) + + R_muse = np.array(R_muse) + R_shot = np.array(R_shot) + + return R_muse, R_shot diff --git a/sigpy/mri/mussels.py b/sigpy/mri/mussels.py new file mode 100644 index 00000000..a337eff0 --- /dev/null +++ b/sigpy/mri/mussels.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +""" +MUSSELS Diffusion MRI Reconstruction. + +Author: Zhengguo Tan +""" + +import numpy as np +import sigpy as sp + +from sigpy import backend +from sigpy.mri import app, sms + + +def MusselsRecon(y, coils, MB=1, + lamda=0.05, max_iter=50, rho=0.1, + regu='SLR', regu_kspace=True, thresh='hard', + blk_shape=[1, 7, 7], blk_strides=[1, 1, 1], + use_readout_extend_fov=False, ro_extend_fold=1, + yshift=None, + verbose=True, + device=sp.cpu_device): + """MUSSELS Reconstruction. + + Args: + y (array): measured k-space data. + coils (array): coil sensitivity maps. + MB (int): multi-band factor. + lamda (float): regularization strength. + max_iter (int): maximal iteration steps. + rho (float): ADMM rho. + regu (string): regularization type ('SLR', 'TIK', 'LLR', 'TV'). + regu_kspace (boolean): regularization performed in k-space. + thresh (string): thresholding type ('hard', 'soft'). + blk_shape (tuple or list of int): block shape (default: [7, 7]). + blk_strides (tuple or list of int): stride shape (default: [1, 1]). + use_readout_extend_fov (boolean): use the method + readout extended fov for image reconstruction. + ro_extend_fold (int): unextend readout extended FOV + for regularization (1 or MB). + verbose (boolean): output ADMM iteration info. + device (int): cpu or gpu device to run the recon. + + References: + * Mani M, Jacob M, Kelley D, Magnotta V. + Multi-shot sensitivity-encoded diffusion data recovery using + structured low-rank matrix completion (MUSSELS). + Magn Reson Med 2017;78:494-507. + + * Bilgic B, Chatnuntawech I, Manhard MK, Tian Q, Liao C, Iyer SS, + Cauley SF, Huang SY, Polimeni JR, Wald LL, Setsompop K. + High accelerated multishot echo planar imaging through + synergistic machine learning and joint reconstruction. + Magn Reson Med 2019;82:1343-1358. + + * Mani M, Aggarwal HK, Magnotta V, Jacob M. + Improved MUSSELS reconstruction for high-resolution + multi-shot diffusion weighted imaging. + Magn Reson Med 2020;83:2253-2263. + + * Dai E, Mani M, McNab JA. + Multi-band multi-shot diffusion MRI reconstruction + with joint usage of structured low-rank constraints + and explicit phase mapping. + Magn Reson Med 2023;89:95-111. + """ + Ndiff, Nshot, Ncoil, Nz_collap, Ny, Nx = y.shape + assert (Nshot > 1) # MUSSELS is a multi-shot technique + + _Ncoil, Nz, _Ny, _Nx = coils.shape + + assert ((Ncoil == _Ncoil) and (Ny == _Ny) and (Nx == _Nx)) + assert ((Nz_collap == Nz / MB)) + + output = [] + for z in range(Nz_collap): # loop over collapsed k-space + + slice_idx = sms.get_uncollap_slice_idx(Nz, MB, z) + mps = coils[:, slice_idx, ...] + + ksp_slice = y[..., z, :, :] # 5 + + for d in range(Ndiff): # loop over diffusion encodings (DWI) + + print('>> mussels on slice ' + str(z).zfill(2) + + ' diff ' + str(d).zfill(3)) + + ksp = ksp_slice[d, ...] # 4 + + # readout extended FOV + if use_readout_extend_fov is True: + + ksp_ext, mps_ext, msk_ext = sms.readout_extended_fov(ksp, + mps, + MB) + + ksp_ext = np.expand_dims(ksp_ext, axis=(-3, 0)) # 6 + msk_ext = np.expand_dims(msk_ext, axis=(-3, 0)) # 6 + mps_ext = np.expand_dims(mps_ext, axis=(-3, )) + + sms_phase = sms.get_sms_phase_shift([1, Ny, Nx * MB], + MB=1, + yshift=[0]) + + else: + + ksp_ext = ksp[None, :, :, None, :, :] # 6 dim + mps_ext = mps.copy() + msk_ext = None + + sms_phase = sms.get_sms_phase_shift([MB, Ny, Nx], + MB=MB, + yshift=yshift) + + # structured low rank matrix completion (SLRMC) as a regularizer + img_ext = app.HighDimensionalRecon( + ksp_ext, mps_ext, lamda, max_iter=max_iter, + combine_echo=False, weights=msk_ext, + phase_sms=sms_phase, + regu=regu, regu_kspace=regu_kspace, + blk_shape=blk_shape, blk_strides=blk_strides, + thresh=thresh, solver='ADMM', rho=rho, + ro_extend_fold=ro_extend_fold, + verbose=verbose, + device=device).run() + + img = backend.to_device(img_ext) + + if use_readout_extend_fov is True: + img = sms.readout_unextend_fov(img, MB) + phi = sms.get_sms_phase_shift(img.shape, MB, yshift=yshift) + img = sp.ifft(np.conj(phi) * sp.fft(img, axes=[-2, -1]), + axes=[-2, -1]) + + output.append(img) + + output = np.array(output) + + return output diff --git a/sigpy/mri/nlop.py b/sigpy/mri/nlop.py new file mode 100644 index 00000000..ff9d7a85 --- /dev/null +++ b/sigpy/mri/nlop.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +"""MRI non-linear operators. + +This module contains these non-linear operators: + + * Nlinv, + joint coil sensitivity maps and image content. + + * Diffusion, + exponential diffusion modelling and parallel imaging sampling. + +Author: Zhengguo Tan +""" +import sigpy as sp + +from sigpy import nlop +from sigpy.mri import linop + + +class Nlinv(sp.nlop.Nlop): + """ + Construction of the non-linear parallel imaging (nlinv) operator. + + Given the unknown x = (rho, c_1, ..., c_N)^T + , where + rho: image content, and + c_1, ..., c_N: N coil sensitivity maps, + + the forward operation is, + + F(x) = ( ..., FT{rho * c_n}, ... )^T + + , where + n in [1, N], and + FT is either masked FFT or NUFFT. + + Args: + image_shape (tuple): shape of image. + coil_shape (tuple): shape of coils. + coord (None or array): coordinates, i.e. trajectories + coil (None or array): coil sensitivity maps. + W_coil (boolean): apply Sobolev weight on coil or not. + upd_coil (boolean): update coil sensitivity maps or not. + + Reference: + Bauer F., Kannengiesser S. (2007). + An alternative approach to the image reconstruction + for parallel data acquisition in MRI. + Math. Methods Appl. Sci., 30, 1437-1451. + + Uecker M., Hohage T., Block K. T., Frahm J. (2008). + Image reconstruction by regularized nonlinear inversion - + joint estimation of coil sensitivities and image content. + Magn. Reson. Med., 60, 674-682. + """ + + def __init__(self, image_shape, coil_shape, + coord=None, coil=None, + W_coil=True, upd_coil=True, + repr_str=None): + self.image_shape = image_shape + self.coil_shape = coil_shape + + ishape = self._get_xshape() + + self.coord = coord + self.coil = coil + self.upd_coil = upd_coil + + # Sobolev linear operator on coils + if W_coil: + self.W = sp.linop.Sobolev(self.coil_shape) + else: + self.W = sp.linop.Identity(self.coil_shape) + + # FFT or NUFFT operator + x_ndim = len(ishape) + if coord is None: + self.F = sp.linop.FFT(self.coil_shape, axes=range(-x_ndim+1, 0)) + else: + self.F = sp.linop.NUFFT(self.coil_shape, coord) + + oshape = self.F.oshape + + super().__init__(oshape, ishape, repr_str) + + def _get_xshape(self): + + image_ndim = len(self.image_shape) + + if image_ndim == 2: + num_coilimg = 1 + self.coil_shape[0] + else: + num_coilimg = self.image_shape[0] + self.coil_shape[0] + + xshape = [] # empty list + xshape.append(num_coilimg) + + return xshape + list(self.image_shape[-2:]) + + def _forward(self, input): + with sp.backend.get_device(input): + + # store the current estimate into class + self.x = input + + image = self.x[0, :, :] # extract image + coil_ksp = self.x[1:, :, :] # extract coils + coil_img = self.W * coil_ksp + + return self.F(image * coil_img) + + def _get_Jacobian(self, x): + return None + + def _derivative(self, x, dx): + device = sp.backend.get_device(dx) + + self.x = x + + with device: + image = self.x[0, :, :] + coil_ksp = self.x[1:, :, :] + coil_img = self.W * coil_ksp + + dimage = dx[0, :, :] + dcoil_ksp = dx[1:, :, :] + dcoil_img = self.W * dcoil_ksp + + return self.F * (dimage * coil_img + image * dcoil_img) + + def _adjoint(self, x, dy): + device = sp.backend.get_device(dy) + xp = device.xp + + self.x = x + + output = xp.zeros_like(self.x) + + with device: + image = self.x[0, :, :] + coil_ksp = self.x[1:, :, :] + coil_img = self.W * coil_ksp + + dcoilimg = self.F.H * dy + + output[0, :, :] = xp.sum(xp.conj(coil_img) * dcoilimg, axis=0) + + if self.upd_coil: + output[1:, :, :] = self.W.H(xp.conj(image) * dcoilimg) + + return output + + +# class Diffusion(sp.nlop.Nlop): +def Diffusion(input_shape, diff_enc, coil, + scale=None, rvc=False, dwi_phase=None, + coord=None, weights=None): + """ + Construction of the non-linear Diffusion operator. + + Given the unknown x = D + , where + D: diffusion tensor, + + the forward operation is, + + A(x) = E * P, and E = F * S + + , where + E (linop): the SENSE linear operator, + F (linop): k-space sampling operator, + S (linop): multiply with coil sensitivity maps, and + P (nlop): exponential diffusion model. + + Args: + input_shape (tuple): shape of input images. + diff_enc (array): diffusion encoding matrix, + i.e. the output matrix from sp.mri.epi.get_B(). + coil (array): coil sensitivity maps. + coord (None or array): coordinates, i.e. trajectories. + """ + const_b0 = True if input_shape[0] == 6 or input_shape[0] == 21 else False + + P = nlop.Exponential(input_shape, diff_enc, + const_a=const_b0, rvc=rvc, scale=scale) + + # phase correction for every diffusion-weighted image + if dwi_phase is not None: + I = sp.linop.Multiply(P.oshape, dwi_phase) + else: + I = sp.linop.Identity(P.oshape) + + # parallel imaging forward model (Sense) + E = linop.Sense(coil, ishape=P.oshape, coord=coord, weights=weights) + + # Compose (i.e. Chain) Sense linop with Diffusion nlop + A = E * I * P + A.repr_str = 'Diffusion' + + return A diff --git a/sigpy/mri/nlrecon.py b/sigpy/mri/nlrecon.py new file mode 100644 index 00000000..8f9fb53f --- /dev/null +++ b/sigpy/mri/nlrecon.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +"""Non-Linear MRI reconstruction applications. +""" +import numpy as np + +import sigpy as sp +from sigpy import backend, nlls, fourier +from sigpy.mri import nlop, app, epi + + +class kinv(nlls.NonLinearLeastSquares): + r"""non-linear mri reconstruction. + + Args: + """ + def __init__(self, y, image_shape, coil_shape, + coil=None, W_coil=True, upd_coil=True, + coord=None, x=None, x0=None, + rvc=False, dwi_phase=None, weights=None, + model='Diffusion', sample_time=None, + device=backend.cpu_device, + outer_iter=6, alpha=1., redu=2., + trafos=None, proxf=None, + inner_iter=100, inner_tol=0.01, + scaling=False, + **kwargs): + + y = sp.to_device(y, device=device) + + xp = device.xp + + if model == 'Nlinv': + A = nlop.Nlinv(image_shape, coil_shape, + coord=coord, coil=coil, W_coil=W_coil, + upd_coil=upd_coil) + + if x is None: + with device: + x = xp.ones(A.ishape) * 0.1 + + elif model == 'Diffusion': + # estimate coil + if coil is None: + None + + with device: + weights = app._estimate_weights(y, weights, coord) + + A = nlop.Diffusion(image_shape, sample_time, coil, + rvc=rvc, dwi_phase=dwi_phase, + weights=weights) + + if x is None: + with device: + if image_shape[0] == 6 or image_shape[0] == 21: + x = xp.ones(A.ishape, dtype=y.dtype) * 0. + else: + x_b0 = xp.ones([1] + list(image_shape[1:]), dtype=y.dtype) * 1E-5 + x_D = xp.zeros([image_shape[0]-1] + list(image_shape[1:]), dtype=y.dtype) + + x = xp.concatenate((x_b0, x_D)) + + if x0 is None: + with device: + x0 = 0.9 * x + # x0 = xp.zeros(A.ishape, dtype=y.dtype) + + super().__init__(A, y, x=x, x0=x0, + outer_iter=outer_iter, + alpha=alpha, redu=redu, + trafos=trafos, proxf=proxf, + inner_iter=inner_iter, + inner_tol=inner_tol, + **kwargs) + + def _estimate_scaling(self): + """Estimate scaling of unknowns. + """ + return None + + def run(self): + """Run non-linear reconstruction. + """ + while not self.done(): + self.update() + + return self.x diff --git a/sigpy/mri/retro.py b/sigpy/mri/retro.py new file mode 100644 index 00000000..88639ef4 --- /dev/null +++ b/sigpy/mri/retro.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +"""Methods for Echo-Planar Imaging (EPI) acquisition: + +* retrospectively undersampling phase-encoding direction + +Author: + Zhengguo Tan +""" +import numpy as np + + +def unsamp_ky(kdat, phaenc_axis=-3, + uniform_unsamp=True, unsamp_factor=2): + # find valid phase-encoding lines + kdat1 = np.swapaxes(kdat, phaenc_axis, 0) + kdat2 = np.reshape(kdat1, (kdat1.shape[0], -1)) + kdat3 = np.sum(kdat2, axis=1) + + sampled_phaenc_ind = np.array(np.nonzero(kdat3)).ravel() + sampled_phaenc_len = len(sampled_phaenc_ind) + + loop_shape = [np.prod(kdat.shape[:phaenc_axis])] \ + + list(kdat.shape[phaenc_axis:]) + kdat4 = np.reshape(kdat, loop_shape) + + output = np.zeros_like(kdat4) + + shift_cnt = 0 + + # loop over all high dimensions + for l in range(loop_shape[0]): + if uniform_unsamp: + rinds = np.arange(shift_cnt, sampled_phaenc_len, unsamp_factor) + shift_cnt = (shift_cnt + 1) % unsamp_factor + else: + rinds = np.random.randint(sampled_phaenc_len, + size=(sampled_phaenc_len + // unsamp_factor)) + + rand_unsamp_lines = sampled_phaenc_ind[rinds] + + tmp = kdat4[l, rand_unsamp_lines, ...] + output[l, rand_unsamp_lines, ...] = tmp + + return np.reshape(output, kdat.shape) + + +def split_shots(kdat, phaenc_axis=-2, shots=2): + """split shots within one diffusion encoding + """ + # find valid phase-encoding lines + kdat1 = np.swapaxes(kdat, phaenc_axis, 0) + kdat2 = np.reshape(kdat1, (kdat1.shape[0], -1)) + + kdat3 = np.sum(kdat2, axis=1) + sampled_phaenc_ind = np.array(np.nonzero(kdat3)).ravel() + sampled_phaenc_len = len(sampled_phaenc_ind) + + out_shape = [shots] + list(kdat2.shape) + output = np.zeros_like(kdat2, shape=out_shape) + + for l in range(sampled_phaenc_len): + s = l % shots + + ind = sampled_phaenc_ind[l] + output[s, ind, :] = kdat2[ind, :] + + output = np.reshape(output, [shots] + list(kdat1.shape)) + output = np.swapaxes(output, 1, phaenc_axis) + + return output diff --git a/sigpy/mri/sim.py b/sigpy/mri/sim.py index 42db136c..0d09e15a 100644 --- a/sigpy/mri/sim.py +++ b/sigpy/mri/sim.py @@ -2,8 +2,11 @@ """MRI simulation functions. """ import numpy as np +import math -__all__ = ["birdcage_maps"] +from sigpy.mri import epi + +__all__ = ['birdcage_maps', 'gradient_echoes', 'diffusion', 'get_subspace'] def birdcage_maps(shape, r=1.5, nzz=8, dtype=complex): @@ -56,3 +59,149 @@ def birdcage_maps(shape, r=1.5, nzz=8, dtype=complex): out /= rss return out.astype(dtype) + + +def _linspace_to_array(linspace_list): + val_start = linspace_list[0] + val_stop = linspace_list[1] + num = linspace_list[2] + + if num == 1: + return [val_start] + else: + step = (val_stop - val_start) / (num - 1) + return val_start + step * np.arange(num) + + +def gradient_echoes(TE, B0_linspace=(-50, 50, 101), + T2star_linspace=(0.001, 0.200, 100), + rho0_linspace=(0.01, 1, 100)): + """Compute gradient-echo signals + + signal = rho * exp(-TE/T2star) * exp(1i * 2 pi * B0 * TE) + + Args: + TE (array): echo times [second] + B0_linspace (list of floats): start, stop, num [Hz] + T2star_linspace (list of floats): start, stop, num [second] + rho0_linspace (list of floats): start, stop, num [a.u.] + + Returns: + sig (array) + + Author: + Zhengguo Tan + """ + B0_array = _linspace_to_array(B0_linspace) + T2star_array = _linspace_to_array(T2star_linspace) + rho0_array = _linspace_to_array(rho0_linspace) + + sig = np.zeros((len(TE), 1, 1, 1, len(B0_array), + len(T2star_array), len(rho0_array)), dtype=complex) + + for B0_ind in np.arange(B0_linspace[2]): + B0_val = B0_array[B0_ind] + + for T2star_ind in np.arange(T2star_linspace[2]): + T2star_val = T2star_array[T2star_ind] + z = (-1/T2star_val + 1j*2*math.pi*B0_val) + + for rho0_ind in np.arange(rho0_linspace[2]): + rho0_val = rho0_array[rho0_ind] + + sig[:, 0, 0, 0, B0_ind, T2star_ind, rho0_ind] = \ + rho0_val * np.exp(z * TE) + + return sig + + +def diffusion(b=None, g=None, D=(0, 0.004, 10)): + """Compute diffusion signals + + signal = exp(B * D) + + Args: + b (array): echo times [second] + g (array): start, stop, num [Hz] + + Returns: + sig (array) + + Author: + Zhengguo Tan + """ + assert (not ((b is None) ^ (g is None))) + + if b is None: + None + + B = epi.get_B(b, g) + + D_start = D[0] + D_end = D[1] + D_num = D[2] + + D_grid = np.linspace(D_start, D_end, D_num) + + D = np.meshgrid(D_grid, D_grid, D_grid, D_grid, D_grid, D_grid) + D = np.array(D).reshape((6, -1)) + + return np.exp(np.matmul(B, D)) + + +def get_subspace(sig, num_coeffs=5, error_bound=1e-5, prior_err=True): + """Compute linear subspace coefficients of MR signal. + + Args: + sig (array): MR signal + num_coeffs (int): expected number of coefficients + error_bound (float): relative error bound between the subspace signal + and the ground-truth signal + prior_err (boolean): iteratively increase the number of coefficients + in order to reach the error_bound + + Returns: + U_sub (array): truncated subspace matrix + + References: + Huang C., Graff C. G., Clarkson E. W., + Bilgin A., Altbach M. I. (2012). + T2 mapping from highly undersampled data by + reconstruction of principal component coefficient maps + using compressed sensing. + Magn. Reson. Med., 67, 1355-1366. + + Tamir J. I., Uecker M., Chen W., Lai P., Alley M. T., + Vasanawala S. S., Lustig M. (2017). + T2 shuffling: Sharp, multicontrast, volumetric fast spin-echo imaging. + Magn. Reson. Med., 77, 180-195. + + Author: + Zhengguo Tan + """ + def get_rel_error(recov_sig, full_sig): + return np.linalg.norm(full_sig - recov_sig).item() \ + / np.linalg.norm(full_sig).item() + + # reshape sig to (number of echoes, number of dictionary atoms) + sig2 = np.reshape(sig, (sig.shape[-7], -1)) + + # singular value decomposition + U, S, VH = np.linalg.svd(sig2, full_matrices=False) + + while True: + # truncate the U matrix + U_sub = U[:, :num_coeffs] + + # recover from U_sub + recov_sig = U_sub @ U_sub.T @ sig2 + + err = get_rel_error(recov_sig, sig2) + + if (err > error_bound) and (prior_err): + num_coeffs += 1 + else: + print('\nEventual number of subspace coefficients: ' + + str(num_coeffs)) + print('Eventual error: ' + str(err)) + return U_sub diff --git a/sigpy/mri/sms.py b/sigpy/mri/sms.py new file mode 100644 index 00000000..659c5ecb --- /dev/null +++ b/sigpy/mri/sms.py @@ -0,0 +1,318 @@ +"""Functions for Simultaneous Multi-Slice (SMS) Acquisition + +Author: Zhengguo Tan +""" +import numpy as np +import sigpy as sp + +__all__ = ['get_uncollap_slice_idx', + 'get_ordered_slice_idx', + 'reorder_slices', + 'get_sms_phase_shift', + 'readout_extended_fov', + 'readout_unextend_fov'] + + +def is_even(input): + return (input % 2 ==0) + + +def map_acquire_to_ordered_slice_idx(acq_slice_idx, + N_slices_uncollap, N_band, + verbose=False): + """ + Map the acquired slice index to ordered uncollapsed slice indices. + + Args: + acq_slice_idx (int): + acquired slice index + N_slices_uncollap (int): + total number of uncollapsed slices + N_band (int): + multi-band factor + + Output: + ordered uncollapsed slice indices (list of int) + """ + N_slices_collap = N_slices_uncollap // N_band + N_slices_collap_half = N_slices_collap // 2 + + ord_slice_idx = [] + for b in range(N_band): + + # interleaved slice order + if (acq_slice_idx >= N_slices_collap_half) and \ + (is_even(N_slices_collap)): + so = acq_slice_idx * 2 + else: + so = acq_slice_idx * 2 + 1 + + so = (so + b * N_slices_collap) % N_slices_uncollap + ord_slice_idx.append(so) + + # need this when MB=3 and N_slices_uncollap=141 + # if (acq_slice_idx != N_slices_collap_half) or (N_band % 2 == 0): + # and (N_band % 2) and (N_slices_uncollap % 2) + ord_slice_idx.sort() + + if verbose is True: + print('acquired slice: ' + str(acq_slice_idx).zfill(3) + + ' --> ordered slices: ' + str(ord_slice_idx)) + + return ord_slice_idx + + +def get_uncollap_slice_idx(N_slices_uncollap, MB, collap_slice_idx): + """ + Get uncollapsed slice indices for the collapsed slice "collap_slice_idx". + + Args: + N_slices_uncollap (int): total number of uncollapsed slices. + MB (int): multi-band factor. + collap_slice_idx (int): collapsed slice index. + + e.g.: + slice_idx = get_uncollap_slice_idx(30, 2, 0) + returns [0, 7] + """ + slice_idx = [] + N_slices_collap = N_slices_uncollap // MB + + if (collap_slice_idx < 0) or (collap_slice_idx >= N_slices_collap): + raise ValueError('collap_slice_idx must be in the range: [0, ' + + str(N_slices_collap) + ').') + + N_slices_collap_half = N_slices_collap // 2 + + N_slices_1 = N_slices_collap_half * MB + N_slices_2 = N_slices_uncollap - N_slices_1 + + for b in range(MB): + + if collap_slice_idx < N_slices_collap_half: + + N_0 = collap_slice_idx + N_step = N_slices_1 // MB + + else: + + N_0 = N_slices_1 + collap_slice_idx - N_slices_collap_half + N_step = N_slices_2 // MB + + slice_idx.append(N_0 + b * N_step) + + return list(slice_idx) + + +def get_ordered_slice_idx(acq_slice_idx, N_slices): + """ + Get the ordered (geometrically correct) slice indices + from the acquisition slice order. + + Args: + acq_slice_idx (int or tuple of ints): indices for acquired slices. + N_slices (int): total number of slices. + + Output: + ordered slice indices (list). + """ + ordered_slice_idx = list(range(1, N_slices, 2)) \ + + list(range(0, N_slices, 2)) + + if isinstance(acq_slice_idx, int): + return ordered_slice_idx[acq_slice_idx] + else: + return [ordered_slice_idx[i] for i in list(acq_slice_idx)] + + +def reorder_slices_mb1(input, N_slices, slice_axis=-3): + + input = np.swapaxes(input, 0, slice_axis) + + img_shape = input.shape[1:] + + output = np.zeros_like(input, shape=[N_slices] + list(img_shape)) + + for s in range(N_slices): + ord_slice_idx = get_ordered_slice_idx(s, N_slices) + + print('acquired slice: ' + str(s).zfill(3) + + ' --> geometric slice: ' + str(ord_slice_idx).zfill(3)) + + output[ord_slice_idx, ...] = input[s, ...] + + output = np.swapaxes(output, 0, slice_axis) + + return output + + +def reorder_slices_mbx(input, N_band, N_slices, band_axis=-3, slice_axis=0): + + assert (N_band == input.shape[band_axis]) + + N_slices_collap = N_slices // N_band + + assert (N_slices_collap == input.shape[slice_axis]) + + # swap axes such that slice stored in axis 0, and band in axis 1 + input = np.swapaxes(input, 1, band_axis) + input = np.swapaxes(input, 0, slice_axis) + + img_shape = input.shape[2:] # image shape excluding slices and bands + + output = np.zeros_like(input, shape=[N_slices] + list(img_shape)) + + for s in range(N_slices_collap): + + slice_mb_idx = map_acquire_to_ordered_slice_idx(s, N_slices, N_band, + verbose=True) + + output[slice_mb_idx, ...] = input[s, ...] + + output = np.swapaxes(output, 0, -3) + + return output + + +def reorder_slices(input, N_band, N_slices, band_axis=-3, slice_axis=0): + """ + reorder slices after SMS image reconstruction. + + Requirement: + band and slice must be stored in different axes + + Args: + + Output: + + """ + assert (N_band == input.shape[band_axis]) + + N_slices_collap = N_slices // N_band + + assert (N_slices_collap == input.shape[slice_axis]) + + # swap axes such that slice stored in axis 0, and band in axis 1 + input = np.swapaxes(input, 1, band_axis) + input = np.swapaxes(input, 0, slice_axis) + + img_shape = input.shape[2:] # image shape excluding slices and bands + + output = np.zeros_like(input, shape=[N_slices] + list(img_shape)) + + for s in range(N_slices_collap): + acq_slice_idx = get_uncollap_slice_idx(N_slices, N_band, s) + ord_slice_idx = get_ordered_slice_idx(acq_slice_idx, N_slices) + + print('collapsed slice: ' + str(s).zfill(3) + + ' --> acquired slices: ' + + str([str(sid).zfill(3) for sid in acq_slice_idx]) + + ' --> reordered slices: ' + + str([str(sid).zfill(3) for sid in ord_slice_idx])) + + output[ord_slice_idx, ...] = input[s, ...] + + # slice stored in axis -3 + output = np.swapaxes(output, 0, -3) + + return output + + +def get_sms_phase_shift(ishape, MB, yshift=None): + """ + Args: + ishape (tuple or list): input shape of [..., Nz, Ny, Nx]. + MB (int): multi-band factor. + yshift (tuple or list): use custom yshift. + + References: + * Breuer FA, Blaimer M, Heidemann RM, Mueller MF, + Griswold MA, Jakob PM. + Controlled aliasing in parallel imagin results in + higher acceleration (CAIPIRINHA) for multi-slice imaging. + Magn. Reson. Med. 53:684-691 (2005). + """ + Nz, Ny, Nx = ishape[-3:] + + phi = np.ones(ishape, dtype=complex) + + bas = 2 * np.pi / 2 + + if yshift is None: + yshift = (np.arange(Nz)) / MB + else: + assert (len(yshift) == Nz) + + print(' > sms: yshift ', yshift) + + lx = np.arange(Nx) - Nx // 2 + ly = np.arange(Ny) - Ny // 2 + mx, my = np.meshgrid(lx, ly) + + for z in range(Nz): + slice_yshift = bas * yshift[z] + + phi[..., z, :, :] = np.exp(1j * my * slice_yshift) + + return phi + + +def readout_extended_fov(ksp, mps, MB): + """ + References: + * Koopmans PJ, Poser BA, Breuer FA. + 2D-SENSE-GRAPPA for fast, ghosting-robust reconstruction of + in-plane and slice accelerated blipped-CAIPI-EPI. + Proc. Intl. Soc. Magn. Reson. Med. 2015, page 2410. + """ + Ncoil, Ny, Nx = ksp.shape[-3:] + _Ncoil, Nz, _Ny, _Nx = mps.shape[-4:] + + assert ((Ncoil == _Ncoil) and (Nz == MB) and (Ny == _Ny) and (Nx == _Nx)) + + msk = (sp.rss(ksp, axes=(-3, ), keepdims=True) > 0.).astype(ksp.dtype) + + ksp_ext = np.zeros(list(ksp.shape[:-3]) + [Ncoil] + [Ny] + [Nx * MB], + dtype=ksp.dtype) + msk_ext = np.zeros(list(ksp.shape[:-3]) + [1] + [Ny] + [Nx * MB], + dtype=ksp.dtype) + mps_ext = np.zeros(list(mps.shape[:-3]) + [Ny] + [Nx * MB], + dtype=mps.dtype) + + for b in range(MB): # loop over multi bands + + x_idx = np.arange(b*Nx, (b+1)*Nx, 1) + mps_ext[..., x_idx] = mps[..., b, :, :] + ksp_ext[..., x_idx] = sp.ifft(ksp, axes=[-2, -1]) + + k_idx = np.arange(0, Nx * MB, MB) + msk_ext[..., k_idx] = msk + # ksp_ext[..., k_idx] = ksp + + ksp_ext = msk_ext * sp.fft(ksp_ext, axes=[-2, -1]) + + return ksp_ext, mps_ext, msk_ext + + +def readout_unextend_fov(input, MB): + """unextend readout extended FOV images. + + Args: + input (array): input image of shape [..., Ny, Nx * MB]. + MB (int): multi-band factor. + + Output: + image array of shape [..., MB, Ny, Nx]. + """ + Ny, Nx_ext = input.shape[-2:] + + Nx = int(Nx_ext // MB) + + output = np.zeros(list(input.shape[:-2]) + [MB] + [Ny] + [Nx], + dtype=input.dtype) + + for b in range(MB): + x_idx = np.arange(b*Nx, (b+1)*Nx, 1) + output[..., b, :, :] = input[..., x_idx] + + return output diff --git a/sigpy/nlls.py b/sigpy/nlls.py new file mode 100644 index 00000000..15d16e4e --- /dev/null +++ b/sigpy/nlls.py @@ -0,0 +1,192 @@ +""" +This module implements non-linear least square (nlls) algorithms. +Author: Zhengguo Tan +""" +import numpy as np +import sigpy as sp + +from sigpy import backend, alg, prox, util + +class NonLinearLeastSquares(alg.Alg): + """Abstraction for non-linear least square (nlls) algorithms. + + Args: + A (nlop): non-linear operator. + y (array): measurements. + x (None or array): initial guess of unknown. Default None. + x0 (None or array): previous estimate of unknown. Default None. + outer_iter (int): outer iteration steps. Default 6. + alpha (float): regularization strength. Default 1. + redu (float): reduction rate along outer iteration. Default 2. + trafos (None or linop): transformation on x. Default None. + proxf (None or prox): proximal operator. Default None. + inner_iter (int): inner iteration steps. Default 100. + inner_tol (float): inner iteration tolerance. Default 0.01. + + References: + Bauer F., Kannengiesser S. (2007). + An alternative approach to the image reconstruction for parallel data acquisition in MRI. + Math. Meth. Appl. Sci. 30, 1437-1451. + + Uecker M., Hohage T., Block K. T., Frahm J. (2008) + Image reconstruction by regularized nonlinear inversion -- joint estimation of coil sensitivities and image content. + Magn. Reson. Med. 60, 674-682. + + Tan Z., Roeloffs V., Voit D., Joseph A. A., Untenberger M., Merboldt K. D., Frahm J. (2016). + Model-based reconstruction for real-time phase-contrast flow MRI: Improved spatiotemporal accuracy. + Magn. Reson. Med. 77, 1082-1093. + """ + def __init__(self, A, y, x=None, x0=None, + outer_iter=6, alpha=1., redu=2., + trafos=None, proxf=None, + inner_iter=100, inner_tol=0.01): + self.A = A + self.y = y + self.device = backend.get_device(y) + xp = self.device.xp + + self.trafos = trafos + self.proxf = proxf + + # outer iteration + self.outer_iter = outer_iter # max_iter + self.alpha = alpha + self.redu = redu + + # inner iteration + self.inner_iter = inner_iter + self.inner_tol = inner_tol + + with self.device: + # initialize x + self.x = xp.zeros(A.ishape, dtype=y.dtype) + if x is not None: + self.x = sp.to_device(x, device=self.device) + + # initialize x0 + self.x0 = xp.zeros(A.ishape, dtype=y.dtype) + if x0 is not None: + self.x0 = sp.to_device(x0, device=self.device) + + # FIXME: splitx + if self.A.repr_str == 'Nlinv': + self.splitx = 1 + elif self.A.repr_str == 'Diffusion': + self.splitx = self.A.ishape[0] + + super().__init__(outer_iter) + + def _update(self): + """Perfom outer iteration steps of the nonlinear problem. + """ + xp = self.device.xp + + with self.device: + self.dx = xp.zeros_like(self.x) + + self.r = self.y - self.A(self.x) + + # residual + resid = xp.linalg.norm(self.r).item() + + print("iter: " + "%2d"%(self.iter) + "; alpha: " + "%.2f"%(self.alpha) + "; resid: " + "%4.3f"%(resid)) + + self.p = self.A.adjoint(self.x, self.r) + + self.p += self.alpha * (self.x0 - self.x) + + # update dx + self.lls() + + self.x += 1. * self.dx + + self.alpha /= self.redu + + + def _done(self): + """Stopping criteria for the outer iteration. + """ + return self.iter >= self.outer_iter + + + def lls(self): + """Inner linear least square (lls) solver + for the linearized non-linear problem. + """ + xp = self.device.xp + + with self.device: + def AHA(x): + return self.A.adjoint(self.x, self.A.derivative(self.x, x)) + + # tolerance + tol = self.inner_tol * xp.linalg.norm(self.p).item() + + if self.trafos is None and self.proxf is None: + """Conjugate gradient method + """ + + AHA_L2 = lambda x: AHA(x) + self.alpha * x + alg_lls = alg.ConjugateGradient(AHA_L2, self.p, self.dx, max_iter=self.inner_iter, tol=tol) + while not alg_lls.done(): + alg_lls.update() + + else: + """Primal dual method + """ + einit = util.randn(self.x.shape, + dtype=self.x.dtype, + device=self.device) + + alg_eig = alg.PowerMethod(AHA, einit, max_iter=10) + while not alg_eig.done(): + alg_eig.update() + + sigma = 1 + tau = 1 / alg_eig.max_eig + theta = 1 + self.inner_tol *= self.alpha + inner_iter = int(min(self.inner_iter, + 10 * 2**(np.log(1/self.alpha)))) + + proxf1c = prox.L2Reg(self.y.shape, 1, y=-self.r) + + self.proxf.lamda *= self.alpha + proxf2c = prox.Conj(self.proxf) + + proxg = prox.L2Reg(self.x.shape, self.alpha, y=self.x0 - self.x) + + + # initialization + dx_ext = self.x.copy() + u_F1 = xp.zeros_like(self.y) + + u_F2 = xp.zeros_like(self.y, shape=proxf2c.shape) + + for ninner in range(inner_iter): + + # Update dual 1 + util.axpy(u_F1, sigma, self.A.derivative(self.x, dx_ext)) + backend.copyto(u_F1, proxf1c(sigma, u_F1)) + + # Update dual 2 + util.axpy(u_F2, sigma, self.trafos(dx_ext[:self.splitx, ...])) + backend.copyto(u_F2, proxf2c(sigma, u_F2)) + + # Update primal 1 + dx_old = self.dx.copy() + util.axpy(self.dx, -tau, self.A.adjoint(self.x, u_F1)) + util.axpy(self.dx[:self.splitx, ...], tau, self.trafos.H(u_F2)) + backend.copyto(self.dx, proxg(tau, self.dx)) + + dx_dif = self.dx - dx_old + self.inner_resid = xp.linalg.norm(dx_dif / tau**0.5).item() + backend.copyto(dx_ext, self.dx + theta * dx_dif) + + print(" iter: " + "%3d"%(ninner) + + "; resid: " + "%4.6f"%(self.inner_resid)) + + if self.inner_resid < self.inner_tol: + break + + return self.dx diff --git a/sigpy/nlop.py b/sigpy/nlop.py new file mode 100644 index 00000000..f46830ec --- /dev/null +++ b/sigpy/nlop.py @@ -0,0 +1,369 @@ +# -*- coding: utf-8 -*- +"""This module contains an abstraction class Nlop for non-linear operators, +and provides commonly used non-linear operator, including exponential. + +Author: Zhengguo Tan +""" +import numpy as np + +from sigpy import backend, linop + + +class Nlop(): + """Abstraction for non-linear operator. + + Given a nlop A, and an appropriately shaped input x, + the following are valid operations: + + >>> y = A.forward(x) # apply forward operation A on x + >>> y = A(x) + >>> y = A * x + >>> y = A.derivative(x, dx) # apply Jacobian (derivative) of A on x + >>> x = A.adjoint(x, dy) # apply adjoint of derivative on y + + Args: + oshape (tuple): operator output shape. + ishape (tuple): operator input shape. + scale (array or None): Scaling for the input + + Attributes: + oshape: output shape. + ishape: input shape. + forward: apply forward operator. + get_Jacobian: get the operator's Jacobian matrix. + derivative: apply derivative operator. + adjoint: apply adjoint of derivative operator. + + """ + def __init__(self, oshape, ishape, + scale=None, repr_str=None): + self.oshape = oshape + self.ishape = ishape + self.scale = scale + + linop._check_shape_positive(oshape) + linop._check_shape_positive(ishape) + + if repr_str is None: + self.repr_str = self.__class__.__name__ + else: + self.repr_str = repr_str + + def _check_ishape(self, input): + for i1, i2 in zip(input.shape, self.ishape): + if i2 != -1 and i1 != i2: + raise ValueError( + 'input shape mismatch for {s}, got {input_shape}'.format( + s=self, input_shape=input.shape)) + + def _check_oshape(self, output): + for o1, o2 in zip(output.shape, self.oshape): + if o2 != -1 and o1 != o2: + raise ValueError( + 'output shape mismatch for {s}, got {output_shape}'.format( + s=self, output_shape=output.shape)) + + def _get_Jacobian(self, x): + raise NotImplementedError + + def get_Jacobian(self, x): + """Compute the Jacobian matrix of non-linear operator. + """ + try: + self._check_ishape(x) + output = self._get_Jacobian(x) + except Exception as e: + raise RuntimeError('Exceptions from {}.'.format(self)) from e + + return output + + def _forward(self, input): + raise NotImplementedError + + def forward(self, input): + """Apply non-linear forward operation on input. + """ + try: + self._check_ishape(input) + output = self._forward(input) + self._check_oshape(output) + except Exception as e: + raise RuntimeError('Exceptions from {}.'.format(self)) from e + + return output + + def _derivative(self, x, dx): + raise NotImplementedError + + def derivative(self, x, dx): + """Apply derivative operation on input. + """ + try: + self._check_ishape(x) + self._check_ishape(dx) + output = self._derivative(x, dx) + self._check_oshape(output) + except Exception as e: + raise RuntimeError('Exceptions from {}.'.format(self)) from e + + return output + + def _adjoint(self, x, dy): + raise NotImplementedError + + def adjoint(self, x, dy): + """Apply adjoint of derivative operation on input. + """ + try: + self._check_oshape(dy) + self._check_ishape(x) + output = self._adjoint(x, dy) + self._check_ishape(output) + except Exception as e: + raise RuntimeError('Exceptions from {}.'.format(self)) from e + + return output + + def __call__(self, input): + return self.__mul__(input) + + def __mul__(self, input): + if isinstance(input, linop.Linop): + return Compose([self, input]) + elif isinstance(input, backend.get_array_module(input).ndarray): + return self.forward(input) + + return NotImplemented + + +class Exponential(Nlop): + """ + Construction of the non-linear exponential operator. + + Given the unknown x = (b, a, R)^T, where + b: bias array, scalar, or None, + a: encoding array or scalar, and + R: relaxation rate, + and the encoding array encode, + the forware operation is + F(x) = b + a * exp(encode * R). + + Args: + ishape (tuple): input shape + encode (array): echo times or B matrix + bias (boolean): have bias (b) in the model or not + """ + def __init__(self, ishape, encode, + bias=False, const_a=False, + scale=None, rvc=False, repr_str=None): + image_shape = ishape[1:] + num_param = ishape[0] + + num_encode, num_relax = encode.shape + + assert num_relax == num_param - 2 if bias else num_param - 1 + + self.encode = encode + self.bias = bias + self.const_a = const_a + self.rvc = rvc + + if scale is None: + scale = np.ones([num_param] + [1] * len(image_shape)) + else: + assert scale.shape[0] == num_param + + oshape = [num_encode] + list(image_shape) + + super().__init__(oshape, ishape, scale=scale, repr_str=repr_str) + + def _check_coil_image_shape(self, coil_shape, image_shape): + for i1, i2 in zip(coil_shape, image_shape): + if i1 != i2: + raise ValueError('coils and image have different shape.') + + def get_params(self, x): + """Split the unknown x into (b, a, R) + """ + xp = backend.get_device(x).xp + + with backend.get_device(x): + xscale = self.scale * x + ind = 0 + + if self.bias is True: + b = xscale[ind, ...] + ind += 1 + else: + b = None + + if self.const_a is True: + a = xp.ones_like(xscale, shape=xscale.shape[1:]) + else: + a = xscale[ind, ...] + ind += 1 + + R = xscale[ind:, ...] + + return b, a, R + + def _forward(self, input): + device = backend.get_device(input) + xp = device.xp + + with device: + self.encode = backend.to_device(self.encode, device=device) + self.scale = backend.to_device(self.scale, device=device) + + self.x = input + + b, a, R = self.get_params(self.x) + + Rr = xp.reshape(R, (R.shape[0], -1)) + output = xp.exp(xp.matmul(self.encode, Rr)) + output = xp.reshape(output, self.oshape) + + output *= a + + if self.bias is True: + output += b + + return output + + def _get_Jacobian(self, x): + # For the computation of Jacobian, self.x must exist. + self.x = x + device = backend.get_device(self.x) + xp = device.xp + + image_shape = self.x.shape[1:] + + with device: + self.encode = backend.to_device(self.encode, device=device) + + b, a, R = self.get_params(self.x) + + Jshape = [] # Nr. of encoding + xshape + Jshape.append(self.encode.shape[0]) + Jshape += list(self.x.shape) + + output = xp.zeros_like(self.x, shape=Jshape) + + ind = 0 + + # Jacobian for b + if self.bias is True: + output[:, ind, ...] = xp.ones_like( + self.x, shape=image_shape) + ind += 1 + + Rr = xp.reshape(R, (R.shape[0], -1)) + z = xp.exp(xp.matmul(self.encode, Rr)) + z = xp.reshape(z, self.oshape) + + # Jacobian for a + if self.const_a is False: + output[:, ind, ...] = z + ind += 1 + + # Jacobian for R + encode = xp.reshape(self.encode, list(self.encode.shape) + + [1] * len(image_shape)) + Z = xp.reshape(a * z, [z.shape[0]] + [1] + list(image_shape)) + output[:, ind:, ...] = encode * Z + + return output + + def _derivative(self, x, dx): + device = backend.get_device(dx) + xp = device.xp + + with device: + self.Jacobian = self._get_Jacobian(x) + return xp.sum(self.Jacobian * dx, axis=1) + + def _adjoint(self, x, dy): + device = backend.get_device(dy) + xp = device.xp + + with device: + self.Jacobian = self._get_Jacobian(x) + JH = xp.conjugate(xp.moveaxis(self.Jacobian, 0, 1)) + dx = xp.sum(JH * dy, axis=1) + + if self.rvc: + dx = dx.real + 0. * 1j + + return dx + + +def _check_compose_nlops(ops): + for op1, op2 in zip(ops[:-1], ops[1:]): + if (op1.ishape != op2.oshape): + raise ValueError('cannot compose {op1} and {op2}.'.format( + op1=op1, op2=op2)) + + +def _combine_compose_nlops(ops): + combined_nlops = [] + for op in ops: + if isinstance(op, Compose): + combined_nlops += op.ops + else: + combined_nlops.append(op) + + return combined_nlops + + +class Compose(Nlop): + """Composition of non-linear operators. + + Args: + ops (list of operators): (linops and/or nlops) to be composed. + + Returns: + Nlop: op[0] * op[1] * ... * op[n - 1] + """ + def __init__(self, ops): + _check_compose_nlops(ops) + self.ops = _combine_compose_nlops(ops) + + super().__init__( + self.ops[0].oshape, self.ops[-1].ishape, + repr_str=' * '.join([op.repr_str for op in ops])) + + def _forward(self, input): + output = input + for op in self.ops[::-1]: + output = op(output) + return output + + def _get_Jacobian(self, input): + output = input + for op in self.ops[::-1]: + if isinstance(op, linop.Linop): + output = op(output) + elif isinstance(op, Nlop): + output = op.get_Jacobian(output) + + return output + + def _derivative(self, x, dx): + output = dx + for op in self.ops[::-1]: + if isinstance(op, linop.Linop): + output = op(output) + elif isinstance(op, Nlop): + output = op.derivative(x, output) + + return output + + def _adjoint(self, x, dy): + output = dy + for op in self.ops[::1]: + if isinstance(op, linop.Linop): + output = op.H(output) + elif isinstance(op, Nlop): + output = op.adjoint(x, output) + + return output diff --git a/sigpy/nn/torch_dce.py b/sigpy/nn/torch_dce.py new file mode 100644 index 00000000..c00b06df --- /dev/null +++ b/sigpy/nn/torch_dce.py @@ -0,0 +1,92 @@ +import torch + +import numpy as np +import torch.nn as nn +import torch.optim as optim + +from sigpy.mri import dce + +# %% +class DCE(nn.Module): + def __init__(self, + ishape, + sample_time, + R1 = 1., + M0 = 5., + R1CA = 4.39, + FA = 15., + TR = 0.006): + super(DCE, self).__init__() + + self.ishape = list(ishape) + + self.sample_time = torch.tensor(np.squeeze(sample_time), dtype=torch.float32) + + self.R1 = torch.tensor(np.array(R1), dtype=torch.float32) + self.M0 = torch.tensor(np.array(M0), dtype=torch.float32) + self.R1CA = torch.tensor(np.array(R1CA), dtype=torch.float32) + self.FA = torch.tensor(np.array(FA), dtype=torch.float32) + self.TR = torch.tensor(np.array(TR), dtype=torch.float32) + + self.FA_radian = self.FA * np.pi / 180. + self.M0_trans = self.M0 * torch.sin(self.FA_radian) + + E1 = torch.exp(-self.TR * self.R1) + self.M_steady = self.M0_trans * (1 - E1) / (1 - E1 * torch.cos(self.FA_radian)) + + Cp = dce.arterial_input_function(sample_time) + self.Cp = torch.tensor(Cp, dtype=torch.float32) + + def _check_ishape(self, input): + for i1, i2 in zip(input.shape, self.ishape): + if i1 != i2: + raise ValueError( + 'input shape mismatch for {s}, got {input_shape}'.format(s=self, input_shape=input.shape)) + + def _param_to_conc(self, x): + t1_idx = torch.nonzero(self.sample_time) + t1 = self.sample_time[t1_idx] + dt = torch.diff(t1, dim=0) + K_time = torch.cumsum(self.Cp, dim=0) * dt[-1] + + mult = torch.stack((K_time, self.Cp), 1) + + xr = torch.reshape(x, (self.ishape[0], np.prod(self.ishape[1:]))) + + yr = torch.matmul(mult, xr) + + oshape = [len(self.sample_time)] + self.ishape[1:] + yr = torch.reshape(yr, tuple(oshape)) + + return yr + + def forward(self, x): + + if torch.is_tensor(x) is not True: + x = torch.tensor(x, dtype=torch.float32) + + self._check_ishape(x) + + # parameters (k_trans, v_p) to concentration + CA = self._param_to_conc(x) + x0 = CA[0, ...] # baseline image + + # concentration to MR signal + E1CA = torch.exp(-self.TR * (self.R1 + self.R1CA * CA)) + + CA_trans = self.M0_trans * (1 - E1CA) / (1 - E1CA * torch.cos(self.FA_radian)) + + y = CA_trans + x0 - self.M_steady + + return y + +# %% +if torch.cuda.is_available(): + device = "cuda:0" +else: + device = "cpu" + +model = DCE() + + +# for epoch in range(20): diff --git a/sigpy/prox.py b/sigpy/prox.py index d80dd1fd..aed4bd41 100644 --- a/sigpy/prox.py +++ b/sigpy/prox.py @@ -4,8 +4,9 @@ l1 ball projection, and box constraints. """ import numpy as np +import random -from sigpy import backend, thresh, util +from sigpy import backend, util, thresh, linop class Prox(object): @@ -84,7 +85,8 @@ def _prox(self, alpha, input): class NoOp(Prox): - r"""Proximal operator for empty function. Equivalant to an identity function. + r"""Proximal operator for empty function. + Equivalant to an identity function. Args: shape (tuple of ints): Input shape @@ -108,7 +110,8 @@ class Stack(Prox): def __init__(self, proxs): self.nops = len(proxs) - assert self.nops > 0 + + assert (self.nops > 0) self.proxs = proxs self.shapes = [prox.shape for prox in proxs] @@ -329,3 +332,233 @@ def _prox(self, alpha, input): with device: return xp.clip(input, self.lower, self.upper) + + +class LLRL1Reg(Prox): + r"""Local Low Rank L1 Regularization + + Args: + shape (tuple of int): input shapes. + lamda (float): regularization parameter. + randshift (boolean): switch on random shift or not. + blk_shape (tuple of int): block shape [default: (8, 8)]. + blk_strides (tuple of int): block strides [default: (8, 8)]. + + References: + * Cai JF, Candes EJ, Shen Z. + A singular value thresholding algorithm + for matrix completion. + SIAM J Optim 20:1956-1982 (2010). + + * Trzasko J, Manduca A. + Local versus global low-rank promotion + in dynamic MRI series reconstruction. + Proc. ISMRM 19:4371 (2011). + + * Zhang T, Pauly J, Levesque I. + Accelerating parameter mapping with a locally low rank constraint. + Magn Reson Med 73:655-661 (2015). + + * Saucedo A, Lefkimmiatis S, Rangwala N, Sung K. + Improved computational efficiency of locally low rank + MRI reconstruction using iterative random patch adjustments. + IEEE Trans Med Imaging 36:1209-1220 (2017). + + * Hu Y, Wang X, Tian Q, Yang G, Daniel B, McNab J, Hargreaves B. + Multi-shot diffusion-weighted MRI reconstruction + with magnitude-based + spatial-angular locally low-rank regularization (SPA-LLR). + Magn Reson Med 83:1596-1607 (2020). + + Author: + Zhengguo Tan + """ + + def __init__(self, shape, lamda, randshift=True, + blk_shape=(8, 8), blk_strides=(8, 8), + reg_magnitude=False, + normalization=False, + verbose=False): + self.lamda = lamda + self.randshift = randshift + self.reg_magnitude = reg_magnitude + self.normalization = normalization + + assert len(blk_shape) == len(blk_strides) + self.blk_shape = blk_shape + self.blk_strides = blk_strides + self.verbose = verbose + + # construct forward linops + self.RandShift = self._linop_randshift(shape, blk_shape, randshift) + self.A = linop.ArrayToBlocks(shape, blk_shape, blk_strides) + self.Reshape = self._linop_reshape() + + self.Fwd = self.Reshape * self.A * self.RandShift + + super().__init__(shape) + + def _check_blk(self): + assert len(self.blk_shape) == len(self.blk_strides) + + def _prox(self, alpha, input): + device = backend.get_device(input) + xp = device.xp + + with device: + + if self.reg_magnitude: + mag = xp.abs(input) + phs = xp.exp(1j * xp.angle(input)) + + else: + mag = input.copy() + phs = xp.ones_like(mag) + + output = self.Fwd(mag) + + u, s, vh = xp.linalg.svd(output, full_matrices=False) + + if self.normalization is True: + s = s / self.blk_shape[-1] + + s_thresh = thresh.soft_thresh(self.lamda * alpha, s) + + if self.normalization is True: + s_thresh = s_thresh * self.blk_shape[-1] + + output = (u * s_thresh[..., None, :]) @ vh + + output = self.Fwd.H(output) + + return output * phs + + def _linop_randshift(self, shape, blk_shape, randshift): + + D = len(blk_shape) + + if randshift is True: + axes = range(-D, 0) + shift = [random.randint(0, blk_shape[s]) for s in axes] + + return linop.Circshift(shape, shift, axes) + else: + return linop.Identity(shape) + + def _linop_reshape(self): + D = len(self.blk_shape) + + oshape = [util.prod(self.A.ishape[:-D]), + util.prod(self.A.num_blks), + util.prod(self.blk_shape)] + + R1 = linop.Reshape(oshape, self.A.oshape) + R2 = linop.Transpose(R1.oshape, axes=(1, 0, 2)) + return R2 * R1 + + +class SLRMCReg(Prox): + r"""Structure Low Rank Matrix Completion as Regularization + + Args: + shape (tuple of int): input shapes. + lamda (float): regularization parameter. + blk_shape (tuple of int): block shape [default: (7, 7)]. + blk_strides (tuple of int): block strides [default: (1, 1)]. + thresh (string): thresholding type ['soft' or 'hard']. + + References: + * Mani M, Jacob M, Kelley D, Magnotta V. + Multi-shot sensitivity-encoded diffusion data recovery using + structured low-rank matrix completion (MUSSELS). + Magn Reson Med 78:494-507 (2017). + + * Bilgic B, Chatnuntawech I, Manhard MK, Tian Q, + Liao C, Iyer SS, Cauley SF, Huang SY, + Polimeni JR, Wald LL, Setsompop K. + Highly accelerated multishot echo planar imaging through + synergistic machine learning and joint reconstruction. + Magn Reson Med 82:1343-1358 (2019). + + * Dai E, Mani M, McNab JA. + Multi-band multi-shot diffusion MRI reconstruction with + joint usage of structured low-rank constraints + and explicit phase mapping. + Magn Reson Med 89:95-111 (2023). + + Author: + Zhengguo Tan + """ + def __init__(self, shape, lamda, + blk_shape=(7, 7), blk_strides=(1, 1), + thresh='hard', verbose=False): + self.lamda = lamda + + assert len(blk_shape) == len(blk_strides) + self.blk_shape = blk_shape + self.blk_strides = blk_strides + self.thresh = thresh + self.verbose = verbose + + # construct forward linops + self.A = linop.ArrayToBlocks(shape, blk_shape, blk_strides) + self.Reshape = self._linop_reshape() + + self.Fwd = self.Reshape * self.A + + super().__init__(shape) + + def _prox(self, alpha, input): + device = backend.get_device(input) + xp = device.xp + + with device: + + output = self.Fwd(input) + + # SVD + u, s, vh = xp.linalg.svd(output, full_matrices=False) + + if self.thresh == 'soft': # soft thresholding + + s_thresh = thresh.soft_thresh(self.lamda * alpha, s) + + output = (u * s_thresh[..., None, :]) @ vh + + else: # hard thresholding + + keep = int(self.lamda * alpha * len(s)) + + if keep >= len(s): + keep = len(s) + + if self.verbose: + print('>>> shape of the array for SVD: ', output.shape) + print('>>> # of singular values kept ' + str(keep) + + ' of ' + str(len(s))) + + u_t, s_t, vh_t = u[..., :keep], s[:keep], vh[..., :keep, :] + + output = (u_t * s_t[..., None, :]) @ vh_t + + output = self.Fwd.H(output) + + return output + + def _linop_reshape(self): + D = len(self.blk_shape) + + oshape1 = [util.prod(self.A.ishape[:-D]), + util.prod(self.A.num_blks), + util.prod(self.blk_shape)] + + R1 = linop.Reshape(oshape1, self.A.oshape) + R2 = linop.Transpose(R1.oshape, axes=(0, 2, 1)) + + oshape2 = [util.prod(R2.oshape[:-1]), + R2.oshape[-1]] + + R3 = linop.Reshape(oshape2, R2.oshape) + R4 = linop.Transpose(R3.oshape, axes=(1, 0)) + + return R4 * R3 * R2 * R1 diff --git a/sigpy/pytorch.py b/sigpy/pytorch.py index 45b7683c..703161ea 100644 --- a/sigpy/pytorch.py +++ b/sigpy/pytorch.py @@ -27,6 +27,11 @@ def to_pytorch(array, requires_grad=True): # pragma: no cover from torch.utils.dlpack import from_dlpack device = backend.get_device(array) + + if not array.flags.c_contiguous: + with device: + array = device.xp.ascontiguousarray(array) + if not np.issubdtype(array.dtype, np.floating): with device: shape = array.shape diff --git a/sigpy/sim.py b/sigpy/sim.py index 4affeb0d..7dfa65d5 100644 --- a/sigpy/sim.py +++ b/sigpy/sim.py @@ -21,6 +21,39 @@ def shepp_logan(shape, dtype=complex): return phantom(shape, sl_amps, sl_scales, sl_offsets, sl_angles, dtype) +def dynamic_shepp_logan(shape, dtype=complex, + dynamic_fun='sin', + dynamic_scale=4): + """Generates a moving Shepp Logan phantom + + Args: + shape (tuple of ints): shape, can be of length 3 or 4. + The first dimension is the motion dimension. + dtype (Dtype): data type. + + Reuturns: + array. + + Author: + Zhengguo Tan + """ + + N_frame = shape[0] + + if dynamic_fun == 'sin': + motion = 0.5 * np.sin(2 * np.pi * np.arange(N_frame) / N_frame) + 1 + + output = [] + for f in range(N_frame): + sl_scales_f = np.array(sl_scales) + sl_scales_f[dynamic_scale] *= motion[f] + sl_scales_f = sl_scales_f.tolist() + output.append(phantom(shape[1:], sl_amps, sl_scales_f, + sl_offsets, sl_angles, dtype)) + + return np.array(output) + + sl_amps = [1, -0.8, -0.2, -0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] sl_scales = [ diff --git a/sigpy/util.py b/sigpy/util.py index dd6fe4cd..d4a4591b 100644 --- a/sigpy/util.py +++ b/sigpy/util.py @@ -103,7 +103,7 @@ def split(vec, oshapes): return outputs -def rss(input, axes=(0,)): +def rss(input, axes=(0, ), keepdims=False): """Root sum of squares. Args: @@ -114,7 +114,8 @@ def rss(input, axes=(0,)): array: Result. """ xp = backend.get_array_module(input) - return xp.sum(xp.abs(input) ** 2, axis=axes) ** 0.5 + + return xp.sum(xp.abs(input)**2, axis=axes, keepdims=keepdims)**0.5 def resize(input, oshape, ishift=None, oshift=None): @@ -197,7 +198,8 @@ def circshift(input, shifts, axes=None): if axes is None: axes = range(input.ndim) - assert len(axes) == len(shifts) + assert (len(axes) == len(shifts)) + xp = backend.get_array_module(input) for axis, shift in zip(axes, shifts): @@ -321,12 +323,13 @@ def triang(shape, dtype=float, device=backend.cpu_device): return window -def hanning(shape, dtype=float, device=backend.cpu_device): +def hanning(shape, dtype=float, symm=False, device=backend.cpu_device): """Create multi-dimensional hanning window. Args: shape (tuple of ints): Output shape. dtype (Dtype): Output data-type. + symm (boolean): Symmetric hanning window. device (Device): Output device. Returns: @@ -340,7 +343,13 @@ def hanning(shape, dtype=float, device=backend.cpu_device): window = xp.ones(shape, dtype=dtype) for n, i in enumerate(shape[::-1]): x = xp.arange(i, dtype=dtype) - w = 0.5 - 0.5 * xp.cos(2 * np.pi * x / max(1, (i - (i % 2)))) + + if symm is False: + den = max(1, (i - (i % 2))) + else: + den = max(1, i-1) + + w = 0.5 - 0.5 * xp.cos(2 * np.pi * x / den) window *= w.reshape([i] + [1] * n) return window diff --git a/tests/mri/test_app.py b/tests/mri/test_app.py index 197ece91..fc9b8279 100644 --- a/tests/mri/test_app.py +++ b/tests/mri/test_app.py @@ -4,6 +4,8 @@ import numpy.testing as npt import sigpy as sp + +from sigpy import util, linop, nlop from sigpy.mri import app, sim if __name__ == "__main__": @@ -161,3 +163,37 @@ def test_espirit_maps_eig(self): ).run() np.testing.assert_allclose(eig_val, 1, rtol=0.01, atol=0.01) + + def test_model_based_dti_recon(self): + + N_diffenc = 60 + B0 = np.zeros([1, 6]) + B1 = util.randn([N_diffenc, 6]) + B = np.concatenate((B0, B1)) * 1E3 + + img_shape = [15, 15] + D = util.randn([7, 1] + img_shape, dtype=float) / 1E6 + D = D + 0. * 1j + + E = nlop.Exponential(D.shape, B, rvc=True) + + mps_shape = [8] + img_shape + mps = sim.birdcage_maps(mps_shape) + S = linop.Multiply(E.oshape, mps) + + F = linop.FFT(S.oshape, axes=range(-2, 0)) + + A = F * S * E + + yn = A(D) + util.randn(A.oshape) * 1E-8 + + x = np.concatenate(( + np.ones([1, 1] + img_shape) * 1e-5, + np.zeros([6, 1] + img_shape)), dtype=yn.dtype) + + x = sp.app.NonLinearLeastSquares(A, yn, x=x, + max_iter=6, lamda=1E-3, redu=3, + gn_iter=6, inner_iter=100, + show_pbar=False).run() + + npt.assert_allclose(x, D, rtol=1E-5, atol=1E-5) diff --git a/tests/mri/test_dce.py b/tests/mri/test_dce.py new file mode 100644 index 00000000..65e9bdb2 --- /dev/null +++ b/tests/mri/test_dce.py @@ -0,0 +1,52 @@ +import unittest +import numpy as np +import numpy.testing as npt +from sigpy import app, util +from sigpy.mri import dce + +if __name__ == '__main__': + unittest.main() + +class TestDCE(unittest.TestCase): + + def test_dce(self): + + # %% DCE Sample Time + acq_time = 10 # total acquisition time (minute) + temp_res = 12 # frames per minute + + frames = acq_time * temp_res + 1 + + delay = 8 # baseline frames + + sample_time_0 = np.zeros([1, delay]) + sample_time_1 = (np.arange(1, frames-delay+1, 1) * (1/temp_res)).reshape((1, -1)) + + sample_time = np.hstack((sample_time_0, sample_time_1)) + + # %% DCE Parameters (K_trans, v_p)^T + K_trans = np.array([0.0402, 0.2505]).reshape((1, -1)) + v_p = np.array([0.05, 0.06]).reshape((1, -1)) + + param = [] + param.append(K_trans) + param.append(v_p) + + param = np.array(param) + param = param[:, None, None, :, :] + + # %% DCE Model + DCE = dce.DCE(param.shape, sample_time) + sig = DCE(param) + util.randn(DCE.oshape) * 1E-6 + + # %% NLLS Solver + x = np.ones_like(param, dtype=complex) * 0.1 + + x = app.NonLinearLeastSquares(DCE, sig, x=x, + lamda=1E-3, redu=3, + gn_iter=8, + inner_iter=100, + show_pbar=False, + verbose=True).run() + + npt.assert_allclose(x, param, rtol=1E-5, atol=1E-5) \ No newline at end of file diff --git a/tests/mri/test_epi.py b/tests/mri/test_epi.py new file mode 100644 index 00000000..f8309db7 --- /dev/null +++ b/tests/mri/test_epi.py @@ -0,0 +1,134 @@ +import unittest +import numpy as np +import numpy.testing as npt + +from sigpy import util +from sigpy.mri import epi + +if __name__ == '__main__': + unittest.main() + + +class TestEpi(unittest.TestCase): + + def test_get_B(self): + b_0 = np.zeros((1, 1)) + b_1 = np.ones((6, 1)) + + b = np.concatenate((b_0, b_1)) + g = np.array([[ 0 , 0 , 0 ], + [ 1 , 0 , 0 ], + [ 0 , 1 , 0 ], + [ 0 , 0 , 1 ], + [ 1 / 2**0.5, 1 / 2**0.5, 0 ], + [ 1 / 2**0.5, 0 , 1 / 2**0.5 ], + [ 0 , 1 / 2**0.5, 1 / 2**0.5 ]]) + + B1 = epi.get_B(b, g) + # xx, xy, yy,xz,yz,zz + B2 = np.array([[ 0 , 0 , 0 , 0, 0, 0 ], + [ 1 , 0 , 0 , 0, 0, 0 ], + [ 0 , 0 , 1 , 0, 0, 0 ], + [ 0 , 0 , 0 , 0, 0, 1 ], + [ 0.5, 1 , 0.5, 0, 0, 0 ], + [ 0.5, 0 , 0 , 1, 0, 0.5 ], + [ 0 , 0 , 0.5, 0, 1, 0.5 ]]) * -1. + + npt.assert_allclose(B1, B2, atol=1e-5, rtol=1e-5) + + def test_get_D(self): + ndif = 81 + nlin = 3 + ncol = 3 + + b_0 = np.zeros((1, 1)) + b_100 = np.ones((20, 1)) * 100 + b_500 = np.ones((20, 1)) * 500 + b_1000 = np.ones((40, 1)) * 1000 + + b = np.concatenate((b_0, b_100, b_500, b_1000)) + + g0 = np.zeros((1, 3)) + g1 = np.random.normal(size=(ndif-1, 3), loc=0, scale=0.212) + + g = np.concatenate((g0, g1)) + + B = epi.get_B(b, g) + + D = util.randn((6, nlin, ncol)) * 1e-3 + Dr = D.reshape(6, -1) + + S0 = np.abs(util.randn((nlin, ncol))) + sig = S0 * np.exp(-np.matmul(B, Dr).reshape(ndif, nlin, ncol)) + + Dinv = epi.get_D(B, sig, fit_only_tensor=True) + + npt.assert_allclose(Dinv, D, atol=1e-5, rtol=1e-5) + + def test_get_eig(self): + # xx, xy, yy, xz, yz, zz + D = np.array([[2.00, 0.00, 1.00, 0.00, 0.00, 0.50], + [1.75, -0.43, 1.25, 0.00, 0.00, 0.50], + [1.50, -0.50, 1.50, 0.00, 0.00, 0.50], + [1.00, 0.00, 2.00, 0.00, 0.00, 0.50]]) + + eigvals_exp = np.array([2.0, 1.0, 0.5]) + + for n in range(0, D.shape[0]): + Dn = np.reshape(D[n, :], (6, 1, 1)) + eigvals, _ = epi.get_eig(Dn) + npt.assert_allclose(np.squeeze(eigvals), eigvals_exp, + atol=3e-3, rtol=3e-3) + + # def test_comp_dipy(self): + # ndif = 115 + # nlin = 3 + # ncol = 3 + + # b_0 = np.zeros((1, 1)) + # b_1000 = np.ones((20, 1)) * 1000 + # b_2000 = np.ones((30, 1)) * 2000 + # b_3000 = np.ones((64, 1)) * 3000 + + # b = np.concatenate((b_0, b_1000, b_2000, b_3000)) + + # g0 = np.zeros((1, 3)) + # g1 = np.random.normal(size=(ndif-1, 3), loc=0, scale=0.212) + # gsum = np.sum(g1**2, axis=1)**0.5 + # g1 = g1 / gsum[:, np.newaxis] + + # g = np.concatenate((g0, g1)) + + # B = epi.get_B(b, g) + + # D = util.randn((6, nlin, ncol)) * 1e-3 + # Dr = D.reshape(6, -1) + + # S0 = np.abs(util.randn((nlin, ncol))) + # sig = S0 * np.exp(-np.matmul(B, Dr).reshape(ndif, nlin, ncol)) + + # Di = epi.get_D(B, sig) + # evals, evecs = epi.get_eig(Di, B=B) + # FA = epi.get_FA(evals) + + + # from dipy.core.gradients import gradient_table + + # gtab = gradient_table(b.flatten(), g, atol=0.1) + + # import dipy.reconst.dti as dti + # from dipy.reconst.dti import fractional_anisotropy, color_fa + # tenmodel = dti.TensorModel(gtab) + + # tenfit = tenmodel.fit(np.transpose(sig, axes=(1, 2, 0))) + + # evals_dp = np.transpose(tenfit.evals, axes=(2, 0, 1)) + # evecs_dp = np.transpose(tenfit.evecs, axes=(2, 3, 0, 1)) + # FA_dp = fractional_anisotropy(tenfit.evals) + + + # evecs0_dp = np.transpose(tenfit.evecs[..., 0], axes=(2, 0, 1)) + + # npt.assert_allclose(evals, evals_dp, atol=3e-3, rtol=3e-3) + # npt.assert_allclose(evecs, evecs_dp, atol=3e-3, rtol=3e-3) + # npt.assert_allclose(FA, FA_dp, atol=3e-3, rtol=3e-3) diff --git a/tests/mri/test_linop.py b/tests/mri/test_linop.py index bbdb7da4..485f6efc 100644 --- a/tests/mri/test_linop.py +++ b/tests/mri/test_linop.py @@ -47,7 +47,9 @@ def test_sense_model_batch(self): for coil_batch_size in [None, 1, 2, 3]: A = linop.Sense(mps, coil_batch_size=coil_batch_size) check_linop_adjoint(A, dtype=complex) - npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]), A * img) + + npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]), + A * img) def test_noncart_sense_model(self): img_shape = [16, 16] @@ -62,6 +64,7 @@ def test_noncart_sense_model(self): A = linop.Sense(mps, coord=coord) check_linop_adjoint(A, dtype=complex) + npt.assert_allclose( sp.fft(img * mps, axes=[-1, -2]).ravel(), (A * img).ravel(), @@ -112,6 +115,7 @@ def test_noncart_sense_model_batch(self): for coil_batch_size in [None, 1, 2, 3]: A = linop.Sense(mps, coord=coord, coil_batch_size=coil_batch_size) check_linop_adjoint(A, dtype=complex) + npt.assert_allclose( sp.fft(img * mps, axes=[-1, -2]).ravel(), (A * img).ravel(), @@ -138,3 +142,22 @@ def test_sense_model_with_comm(self): A.H(ksp[comm.rank :: comm.size]), np.sum(sp.ifft(ksp, axes=[-1, -2]) * mps.conjugate(), 0), ) + + def test_sense_subspace_model(self): + basis_shape = [80, 5] + img_shape = [5, 1, 16, 16] + mps_shape = [8, 16, 16] + + basis = sp.randn(basis_shape, dtype=complex) + img = sp.randn(img_shape, dtype=complex) + mps = sp.randn(mps_shape, dtype=complex) + + A = linop.Sense(mps, basis=basis) + + check_linop_adjoint(A, dtype=complex) + + full_img = basis @ np.reshape(img, (5, -1)) + full_img = np.reshape(full_img, [80] + img_shape[1:]) + + npt.assert_allclose(sp.fft(full_img * mps, axes=[-1, -2]), + A * img) diff --git a/tests/mri/test_nlop.py b/tests/mri/test_nlop.py new file mode 100644 index 00000000..dadd3666 --- /dev/null +++ b/tests/mri/test_nlop.py @@ -0,0 +1,92 @@ +import unittest +import numpy.testing as npt +from sigpy import backend, config, fourier, linop, util + +from sigpy.mri import nlop + +import sigpy as sp + +if __name__ == '__main__': + unittest.main() + +devices = [backend.cpu_device] +if config.cupy_enabled: + devices.append(backend.Device(0)) + + +class TestNlop(unittest.TestCase): + + def test_nlinv_model(self): + for device in devices: + xp = device.xp + + I = util.randn((1, 3, 3), dtype=complex, device=device) + C = util.randn((4, 3, 3), dtype=complex, device=device) + + A = nlop.Nlinv(I.shape, C.shape, W_coil=False) + + x = device.xp.ones(A.ishape, dtype=complex) + x[0, :, :] = I + x[1:, :, :] = C + + F = linop.FFT(C.shape, axes=(-2, -1)) + + # test forward + y1 = F(I*C) + y2 = A.forward(x) + + npt.assert_allclose(backend.to_device(y2), + backend.to_device(y1), + err_msg='forward operator!') + + # test derivative + dx = util.randn(x.shape, dtype=complex, device=device) + + dy1 = F(dx[0, :, :] * C + I * dx[1:, :, :]) + dy2 = A.derivative(x, dx) + + npt.assert_allclose(backend.to_device(dy2), + backend.to_device(dy1), + err_msg='derivative operator!') + + # test adjoint + dx1 = xp.zeros(x.shape, dtype=complex) + + dI1 = xp.sum(xp.conj(C) * F.H(dy1), axis=0) + dC1 = xp.conj(I) * F.H(dy1) + + dx1[0, :, :] = dI1 + dx1[1:, :, :] = dC1 + + dx2 = A.adjoint(x, dy2) + + npt.assert_allclose(backend.to_device(dx2), + backend.to_device(dx1), + err_msg='adjoint operator!') + + def test_diffusion_model(self): + for device in devices: + xp = device.xp + + # Exponential + tvec = util.randn((15, 6), dtype=float, device=device) + b0 = util.randn((1, 1, 3, 3), dtype=complex, device=device) + D = util.randn((6, 1, 3, 3), dtype=complex, device=device) + x = xp.concatenate((b0, D)) + + Dr = xp.reshape(D, (D.shape[0], -1)) + y1 = xp.exp(xp.matmul(tvec, Dr)) + y1 = xp.reshape(y1, (15, 1, 3, 3)) + y1 *= b0 + + # Sense + coils = util.randn((8, 3, 3), dtype=complex, device=device) + + y2 = fourier.fft(coils * y1, axes=(-2, -1)) + + A = nlop.Diffusion(x.shape, tvec, coils) + y = A(x) + + npt.assert_allclose(backend.to_device(y), + backend.to_device(y2), + err_msg='Diffusion model mismatch!') \ No newline at end of file diff --git a/tests/mri/test_sim.py b/tests/mri/test_sim.py new file mode 100644 index 00000000..ccc474fd --- /dev/null +++ b/tests/mri/test_sim.py @@ -0,0 +1,37 @@ +import unittest +import numpy as np +import numpy.testing as npt + +from sigpy.mri import sim + +if __name__ == '__main__': + unittest.main() + + +class TestSim(unittest.TestCase): + + def test_multi_gradient_echoes(self): + # echo time + TE0 = 1.70 + ESP = 1.52 + N_eco = 35 + + TE = (TE0 + np.arange(N_eco) * ESP) * 1E-3 + + # dictionary simulation + dict = sim.gradient_echoes(TE) + U = sim.get_subspace(dict, num_coeffs=22, prior_err=False) + + # simulate a tri-exponential signal + rho = [0.30, 0.30, 0.40] + T2 = [0.02, 0.01, 0.10] # second + B0 = [50, 100, -20] # Hz + + sig = np.zeros_like(TE, dtype=complex) + + for a, b, c in zip(rho, T2, B0): + sig += a * np.exp(-TE / b) * np.exp(1j*2*np.pi * c * TE) + + recon_sig = U @ U.T @ sig + + npt.assert_allclose(recon_sig, sig, atol=1e-3, rtol=1e-3) diff --git a/tests/mri/test_sms.py b/tests/mri/test_sms.py new file mode 100644 index 00000000..a226b99b --- /dev/null +++ b/tests/mri/test_sms.py @@ -0,0 +1,58 @@ +import unittest +import numpy as np +import numpy.testing as npt + +from sigpy.mri import sms + +if __name__ == '__main__': + unittest.main() + + +class TestSms(unittest.TestCase): + + def test_slice_order(self): + list_NS = [94, 114, 60, 159, 141, 117] + list_MB = [2, 3, 2, 3, 3, 3] + + for N_slices_uncollap, MB in zip(list_NS, list_MB): + + print('****** total slices: ' + str(N_slices_uncollap) + + ', multi-band: ' + str(MB) + ' ******') + + N_slices_collap = N_slices_uncollap // MB + + slice_mb_idx = [] + for s in range(N_slices_collap): + + slice_mb_idx += sms.map_acquire_to_ordered_slice_idx(s, N_slices_uncollap, MB, verbose=True) + + slice_mb_idx.sort() + + npt.assert_allclose(slice_mb_idx, + range(N_slices_uncollap), + err_msg = ['error in slice ordering for MB = ' + + str(MB) + ' slices = ' + + str(N_slices_uncollap).zfill(3)] + ) + + def test_reorder(self): + N_slices = 114 + N_band = 3 + + N_slices_collap = N_slices // N_band + + img_shape = [4, 4] + + I = np.zeros([N_slices_collap, N_band] + img_shape) + + for s in range(N_slices_collap): + slice_mb_idx = sms.map_acquire_to_ordered_slice_idx(s, N_slices, N_band, verbose=True) + + for b in range(N_band): + idx = slice_mb_idx[b] + I[s, b, :, :] = idx + + O = sms.reorder_slices_mbx(I, N_band, N_slices) + + for s in range(N_slices): + print('slice idx ' + str(s) + '; value ' + str(O[s, 0, 0])) diff --git a/tests/test_alg.py b/tests/test_alg.py index 4cde1d79..f6496d4b 100644 --- a/tests/test_alg.py +++ b/tests/test_alg.py @@ -27,7 +27,8 @@ def test_PowerMethod(self): A, x = self.Ax_setup(n) x_hat = np.random.random([n, 1]) alg_method = alg.PowerMethod(lambda x: A.T @ A @ x, x_hat) - while not alg_method.done(): + + while (not alg_method.done()): alg_method.update() s_numpy = np.linalg.svd(A, compute_uv=False)[0] @@ -66,7 +67,7 @@ def gradf(x): max_iter=1000, ) - while not alg_method.done(): + while (not alg_method.done()): alg_method.update() npt.assert_allclose(x_sigpy, x_numpy) @@ -79,7 +80,8 @@ def test_ConjugateGradient(self): alg_method = alg.ConjugateGradient( lambda x: A.T @ A @ x + lamda * x, A.T @ y, x, max_iter=1000 ) - while not alg_method.done(): + + while (not alg_method.done()): alg_method.update() npt.assert_allclose(x, x_numpy) @@ -107,7 +109,8 @@ def test_PrimalDualHybridGradient(self): sigma, max_iter=1000, ) - while not alg_method.done(): + + while (not alg_method.done()): alg_method.update() npt.assert_allclose(x, x_numpy) @@ -138,7 +141,8 @@ def h(x_z): alg_method = alg.AugmentedLagrangianMethod( minL, None, h, x_z, None, v, mu ) - while not alg_method.done(): + + while (not alg_method.done()): alg_method.update() x = x_z[:n] @@ -176,7 +180,8 @@ def f(x): alg_method = alg.NewtonsMethod( gradf, inv_hessf, x, beta=beta, f=f ) - while not alg_method.done(): + + while (not alg_method.done()): alg_method.update() npt.assert_allclose(x, x_numpy) @@ -195,7 +200,7 @@ def test_GerchbergSaxton(self): A, y, x0, max_iter=100, tol=10e-9, lamb=lamda ) - while not alg_method.done(): + while (not alg_method.done()): alg_method.update() phs = np.conj(x_numpy * alg_method.x / abs(x_numpy * alg_method.x)) diff --git a/tests/test_app.py b/tests/test_app.py index a32abbf8..38d993f4 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -3,7 +3,8 @@ import numpy as np import numpy.testing as npt -from sigpy import app, linop, prox, util +from sigpy import app, linop, util, prox, nlop + if __name__ == "__main__": unittest.main() @@ -131,3 +132,28 @@ def test_dual_precond_LinearLeastSquares(self): show_pbar=False, ).run() npt.assert_allclose(x_rec, x_lstsq, atol=1e-3) + + def test_NonlinearLeastSquares(self): + N_diffenc = 60 + B0 = np.zeros([1, 6]) + B1 = util.randn([N_diffenc, 6]) + B = np.concatenate((B0, B1)) * 1E3 + + img_shape = [1, 15, 15] + D = util.randn([7] + img_shape, dtype=float) / 1E6 + D = D + 0. * 1j + + E = nlop.Exponential(D.shape, B, rvc=True) + + y = E(D) + util.randn(E.oshape) * 1E-8 + + x = np.concatenate(( + np.ones([1] + img_shape) * 1e-5, + np.zeros([6] + img_shape)), dtype=y.dtype) + + x = app.NonLinearLeastSquares(E, abs(y), x=x, + max_iter=6, lamda=1E-3, redu=3, + gn_iter=6, inner_iter=100, + show_pbar=False).run() + + npt.assert_allclose(x, D, rtol=1E-5, atol=1E-5) diff --git a/tests/test_coord.py b/tests/test_coord.py new file mode 100644 index 00000000..30fc9cda --- /dev/null +++ b/tests/test_coord.py @@ -0,0 +1,22 @@ +import unittest +import numpy as np +import numpy.testing as npt +from sigpy import util, coord + +if __name__ == '__main__': + unittest.main() + +class TestCoord(unittest.TestCase): + + def test_normal(self): + + x = util.randn((10, 1)) + y = util.randn((10, 1)) + z = util.randn((10, 1)) + + r, theta, phi = coord.cartes_to_spheri(x, y, z) + xn, yn, zn = coord.spheri_to_cartes(r, theta, phi) + + npt.assert_allclose(x, xn) + npt.assert_allclose(y, yn) + npt.assert_allclose(z, zn) \ No newline at end of file diff --git a/tests/test_fourier.py b/tests/test_fourier.py index da69ed12..de601e1d 100644 --- a/tests/test_fourier.py +++ b/tests/test_fourier.py @@ -12,6 +12,7 @@ class TestFourier(unittest.TestCase): def test_fft(self): input = np.array([0, 1, 0], dtype=complex) + npt.assert_allclose( fourier.fft(input), np.ones(3) / 3**0.5, atol=1e-5 ) @@ -29,6 +30,7 @@ def test_fft(self): ) input = np.array([0, 1, 0], dtype=complex) + npt.assert_allclose( fourier.fft(input, oshape=[5]), np.ones(5) / 5**0.5, atol=1e-5 ) @@ -42,6 +44,7 @@ def test_fft_dtype(self): def test_ifft(self): input = np.array([0, 1, 0], dtype=complex) + npt.assert_allclose( fourier.ifft(input), np.ones(3) / 3**0.5, atol=1e-5 ) @@ -59,6 +62,7 @@ def test_ifft(self): ) input = np.array([0, 1, 0], dtype=complex) + npt.assert_allclose( fourier.ifft(input, oshape=[5]), np.ones(5) / 5**0.5, atol=1e-5 ) diff --git a/tests/test_linop.py b/tests/test_linop.py index fa28034b..bc922cdf 100644 --- a/tests/test_linop.py +++ b/tests/test_linop.py @@ -16,6 +16,7 @@ class TestLinop(unittest.TestCase): + def check_linop_unitary(self, A, device=backend.cpu_device, dtype=float): device = backend.Device(device) x = util.randn(A.ishape, dtype=dtype, device=device) @@ -266,6 +267,7 @@ def test_Multiply(self): x = np.array([1.0, 2.0], complex) y = np.array([[1.0, 4.0], [3.0, 8.0]], complex) + npt.assert_allclose(A * x, y) def test_Resize(self): @@ -576,3 +578,21 @@ def test_Slice(self): self.check_linop_adjoint(A) self.check_linop_normal(A) self.check_linop_pickleable(A) + + def test_RealValueConstraint(self): + for device in devices: + xp = device.xp + + data = util.randn([6, 6], dtype=complex, device=device) + + RVC = linop.RealValueConstraint(data.shape) + + y1 = RVC(data) + y2 = xp.real(data).astype(complex) + + npt.assert_allclose(y1, y2) + + # self.check_linop_linear(A) + # self.check_linop_adjoint(A) + # self.check_linop_normal(A) + # self.check_linop_pickleable(A) \ No newline at end of file diff --git a/tests/test_nlls.py b/tests/test_nlls.py new file mode 100644 index 00000000..54f08742 --- /dev/null +++ b/tests/test_nlls.py @@ -0,0 +1,55 @@ +import unittest +import numpy.testing as npt +from sigpy import backend, config, util, linop, prox + +from sigpy.mri import nlop, epi, nlrecon + +if __name__ == '__main__': + unittest.main() + +devices = [backend.cpu_device] +if config.cupy_enabled: + devices.append(backend.Device(0)) + +class TestNlls(unittest.TestCase): + + def test_diff_nlls(self): + for device in devices: + xp = device.xp + + b = xp.load('diff_b.npy') + g = xp.load('diff_g.npy') + + for model in ['dti', 'dki']: + if model == 'dti': + B = epi.get_B(b, g) + elif model == 'dki': + B = epi.get_B2(b, g) + + num_params = B.shape[1] + + D = util.randn((num_params, 16, 16), + dtype=float, + device=device) / 1E8 + + D = D + 0. * 1j # real numbers + + coils = util.randn((8, 1, 16, 16), + dtype=complex, + device=device) + + F = nlop.Diffusion(D.shape, B, coil=coils) + y = F(D) + util.randn(F.oshape) * 1e-5 + + x = nlrecon.kinv(y, + D.shape, coils.shape, + coil=coils, + outer_iter=8, redu=3, + inner_iter=100, + model='Diffusion', + sample_time=B + ).run() + + npt.assert_allclose(x, D, + rtol=1e-5, atol=1e-7, + err_msg='Diffusion Nlls failed!') diff --git a/tests/test_nlop.py b/tests/test_nlop.py new file mode 100644 index 00000000..3ee7682c --- /dev/null +++ b/tests/test_nlop.py @@ -0,0 +1,155 @@ +import unittest +import numpy as np +import numpy.testing as npt +from sigpy import backend, config, nlop, util, linop + +import torch +from torch.autograd import Variable + +if __name__ == '__main__': + unittest.main() + +devices = [backend.cpu_device] +if config.cupy_enabled: + devices.append(backend.Device(0)) + + +class TestNlop(unittest.TestCase): + + def check_nlop_derivative(self, A, + device=backend.cpu_device, + dtype=float): + device = backend.Device(device) + + scale = 1e-8 + x = util.randn(A.ishape, dtype=dtype, device=device) + h = util.randn(A.ishape, dtype=dtype, device=device) + + with device: + dy1 = (A.forward(x + scale * h) - A.forward(x)) / scale + dy2 = A.derivative(x, h) + + npt.assert_allclose(backend.to_device(dy1), + backend.to_device(dy2), + atol=1e-5, + err_msg=A.repr_str + ' derivative operator!') + + def check_nlop_adjoint(self, A, device=backend.cpu_device, dtype=float): + device = backend.Device(device) + x = util.randn(A.ishape, dtype=dtype, device=device) + + dx = util.randn(A.ishape, dtype=dtype, device=device) + dy = util.randn(A.oshape, dtype=dtype, device=device) + + xp = device.xp + with device: + lhs = xp.vdot(A.derivative(x, dx), dy) + rhs = xp.vdot(dx, A.adjoint(x, dy)) + + npt.assert_allclose(backend.to_device(lhs), + backend.to_device(rhs), + atol=1e-5, + err_msg=A.repr_str + ' adjoint operator!') + + def test_Exponential(self): + for device in devices: + xp = device.xp + list_bias = [True, False, True, False] + list_coef = [1, 1, 6, 6] + for bias, num_coef in zip(list_bias, list_coef): + num_param = num_coef + 2 if bias is True else num_coef + 1 + I = util.randn((num_param, 1, 3, 3), + dtype=complex, device=device) + image_shape = I.shape[1:] + + num_time = 7 + tvec = util.randn((num_time, num_coef), + dtype=float, device=device) + + A = nlop.Exponential(I.shape, tvec, bias=bias) + + # test forward + y2 = A.forward(I) + y1 = xp.zeros_like(y2) + offset = 1 if bias else 0 + a = I[offset, ...] + R = I[offset+1:, ...] + Rr = xp.reshape(R, (R.shape[0], -1)) + y1 = xp.exp(xp.matmul(tvec, Rr)) + y1 = a * xp.reshape(y1, [num_time] + list(image_shape)) + if bias is True: + y1 += I[0, ...] + + npt.assert_allclose(backend.to_device(y2), + backend.to_device(y1), + err_msg=A.repr_str + ' forward operator!') + + # test derivative + self.check_nlop_derivative(A, device=device, dtype=I.dtype) + + # test adjoint + self.check_nlop_adjoint(A, device=device, dtype=I.dtype) + + def test_Exponential_torch(self): + b0 = util.randn((1, 1, 3, 3), dtype=float) + D = util.randn((6, 1, 3, 3), dtype=float) + I = np.concatenate((b0, D)) + + tvec = util.randn((16, 6), dtype=float) + + A = nlop.Exponential(I.shape, tvec, bias=False) + + dy = np.ones(A.oshape, dtype=float) + dx = A.adjoint(I, dy) + + # PyTorch + b0_torch = Variable(torch.tensor(b0), requires_grad=True) + D_torch = Variable(torch.tensor(D), requires_grad=True) + t_torch = Variable(torch.tensor(tvec), requires_grad=False) + + Dr_torch = torch.reshape(D_torch, (D_torch.shape[0], -1)) + y2 = torch.exp(torch.matmul(t_torch, Dr_torch)) + y2 = torch.reshape(y2, A.oshape) + y2 = (y2 * b0_torch).sum() + + y2.backward() + + db0 = b0_torch.grad.numpy() + dD = D_torch.grad.numpy() + + dx_torch = np.concatenate((db0, dD)) + + npt.assert_allclose(dx, dx_torch, err_msg='Derivative mismatch!') + + def test_Compose(self): + tvec = util.randn((15, 6), dtype=float) + x = util.randn((7, 1, 3, 3), dtype=complex) + + E = nlop.Exponential(x.shape, tvec, bias=False) + + smat = util.randn((8, 3, 3), dtype=complex) + S = linop.Multiply(E.oshape, smat) + + A = S * E + + y1 = smat * E(x) + y2 = A(x) + + npt.assert_allclose(y2, y1, + err_msg=A.repr_str + ' forward operator!') + + dx = util.randn(x.shape, dtype=x.dtype) + + dy1 = smat * E.derivative(x, dx) + dy2 = A.derivative(x, dx) + + npt.assert_allclose(dy2, dy1, + err_msg=A.repr_str + ' derivative operator!') + + dy = util.randn(A.oshape, dtype=x.dtype) + + dx1 = E.adjoint(x, S.H * dy) + dx2 = A.adjoint(x, dy) + + npt.assert_allclose(dx2, dx1, + err_msg=A.repr_str + ' adjoint operator!') diff --git a/tests/test_prox.py b/tests/test_prox.py index 8273c046..166549e4 100644 --- a/tests/test_prox.py +++ b/tests/test_prox.py @@ -85,3 +85,38 @@ def test_BoxConstraint(self): x = np.array([-2, -1, 0, 1, 2]) y = P(None, x) npt.assert_allclose(y, [-1, 0, 0, 0, 1]) + + def test_Conj(self): + shape = [3, 3] + x = util.randn(shape, dtype=float) + + F = linop.FiniteDifference(shape, axes=(-2, -1)) + proxg = prox.L1Reg(F.oshape, 1.) + proxgc = prox.Conj(proxg) + + O = F(x) + y1 = proxgc(1., O) + + x1 = O[0, :, :] + x2 = O[1, :, :] + d1 = np.maximum(1., abs(x1)) + x1n = np.divide(x1, d1) + d2 = np.maximum(1., abs(x2)) + x2n = np.divide(x2, d2) + y2 = np.stack((x1n, x2n)) + + npt.assert_allclose(y1, y2) + + def test_LLRL1Reg(self): + shape = [15, 48, 32] + x = util.randn(shape, dtype=complex) + L = prox.LLRL1Reg(shape, 1) + y = L(0., x) + npt.assert_allclose(y, x) + + def test_SLRMCReg(self): + shape = [15, 48, 32] + x = util.randn(shape, dtype=complex) + L = prox.SLRMCReg(shape, 1, blk_shape=(8, 8)) + y = L(1., x) + npt.assert_allclose(y, x) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 0b624998..485e7ec5 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -24,6 +24,7 @@ def test_to_pytorch(self): array = xp.array([1, 2, 3], dtype=dtype) tensor = pytorch.to_pytorch(array) array[0] = 0 + torch.testing.assert_allclose( tensor, torch.tensor( @@ -41,6 +42,7 @@ def test_to_pytorch_complex(self): array = xp.array([1 + 1j, 2 + 2j, 3 + 3j], dtype=dtype) tensor = pytorch.to_pytorch(array) array[0] = 0 + torch.testing.assert_allclose( tensor, torch.tensor( @@ -120,6 +122,7 @@ def test_to_pytorch_function_complex(self): A, input_iscomplex=True, output_iscomplex=True ).apply x_torch = pytorch.to_pytorch(x) + npt.assert_allclose( f(x_torch).detach().numpy().ravel(), A(x).view(float) ) @@ -128,6 +131,7 @@ def test_to_pytorch_function_complex(self): y_torch = pytorch.to_pytorch(y) loss = (f(x_torch) - y_torch).pow(2).sum() / 2 loss.backward() + npt.assert_allclose( x_torch.grad.detach().numpy().ravel(), A.H(A(x) - y).view(float),