diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9a2d18bb..358f569e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -65,13 +65,9 @@ jobs: - name: Install test dependencies run: | pip install .[test] - - name: Lint with flake8 + - name: Lint with ruff run: | - pip install flake8 - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + ruff check . - name: Check types with MyPy run: | mypy src/mpol --pretty diff --git a/.gitignore b/.gitignore index 036a4156..0e72b548 100644 --- a/.gitignore +++ b/.gitignore @@ -144,4 +144,6 @@ plotsdir runs # hatch-generated version file -src/mpol/mpol_version.py \ No newline at end of file +src/mpol/mpol_version.py + +.ruff_cache \ No newline at end of file diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 9541b359..4f62cd04 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -10,3 +10,4 @@ Contributors * Hannah Grzybowski, `@hgrzy` * Mary Ogborn * Tyler Quinn, `@trq5014` +* Kristin Hopley \ No newline at end of file diff --git a/docs/_static/baselines/src/print_conversions.py b/docs/_static/baselines/src/print_conversions.py index 6c66a697..9599ee95 100644 --- a/docs/_static/baselines/src/print_conversions.py +++ b/docs/_static/baselines/src/print_conversions.py @@ -1,3 +1,8 @@ +import csv + +import numpy as np +from mpol.constants import c_ms + import argparse parser = argparse.ArgumentParser( @@ -6,11 +11,6 @@ parser.add_argument("outfile", help="Destination to save CSV table.") args = parser.parse_args() -import csv - -import numpy as np - -from mpol.constants import c_ms header = ["baseline", "100 GHz (Band 3)", "230 GHz (Band 6)", "340 GHz (Band 7)"] @@ -20,18 +20,18 @@ def format_baseline(baseline_m): if baseline_m < 1e3: - return "{:.0f} m".format(baseline_m) + return f"{baseline_m:.0f} m" elif baseline_m < 1e6: - return "{:.0f} km".format(baseline_m * 1e-3) + return f"{baseline_m * 1e-3:.0f} km" def format_lambda(lam): if lam < 1e3: - return "{:.0f}".format(lam) + " :math:`\lambda`" + return f"{lam:.0f}" + r" :math:`\lambda`" elif lam < 1e6: - return "{:.0f}".format(lam * 1e-3) + " :math:`\mathrm{k}\lambda`" + return f"{lam * 1e-3:.0f}" + r" :math:`\mathrm{k}\lambda`" else: - return "{:.0f}".format(lam * 1e-6) + " :math:`\mathrm{M}\lambda`" + return f"{lam * 1e-6:.0f}" + r" :math:`\mathrm{M}\lambda`" data = [] diff --git a/docs/_static/fftshift/src/plot.py b/docs/_static/fftshift/src/plot.py index f049cd8c..2aac186e 100644 --- a/docs/_static/fftshift/src/plot.py +++ b/docs/_static/fftshift/src/plot.py @@ -1,9 +1,3 @@ -import argparse - -parser = argparse.ArgumentParser(description="Create the fftshift plot") -parser.add_argument("outfile", help="Destination to save plot.") -args = parser.parse_args() - import matplotlib.pyplot as plt import numpy as np from astropy.io import fits @@ -11,9 +5,15 @@ from matplotlib import patches from matplotlib.colors import LogNorm from matplotlib.gridspec import GridSpec - from mpol import coordinates +import argparse + +parser = argparse.ArgumentParser(description="Create the fftshift plot") +parser.add_argument("outfile", help="Destination to save plot.") +args = parser.parse_args() + + fname = download_file( "https://zenodo.org/record/4711811/files/logo_cont.fits", cache=True, diff --git a/docs/changelog.md b/docs/changelog.md index 2600410e..970443b7 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -3,7 +3,11 @@ # Changelog ## v0.3.0 - +- removed explicit type declarations in base MPoL modules. Previously, core representations were set to be in `float64` or `complex128`. Now, core MPoL representations (e.g., {class}`mpol.images.BaseCube`) will follow the [default tensor type](https://pytorch.org/docs/stable/generated/torch.set_default_tensor_type.html), which is commonly `torch.float32`. If you want your model to run fully in `float32` or `complex64`, then be sure that your data is also in these formats, since otherwise PyTorch will promote downstream tensors as needed. Fully `float32` or `complex64` models should be able to run on Apple MPS [#254](https://github.com/MPoL-dev/MPoL/issues/254) +- added {meth}`mpol.utils.convolve_packed_cube` method to convolve a 3D packed image cube with a 2D Gaussian. You can specify major axis, minor axis, and rotation angle. +- added {meth}`mpol.utils.uv_gaussian_taper` to calculate a Gaussian tapering window in the visibility plane. +- added the `vis_ext_Mlam` instance attribute to {class}`mpol.coordinates.GridCoords` for convenience plotting of visibility grids with axes labels in units of M$\lambda$. +- Updated [MPoL-dev/examples](https://github.com/MPoL-dev/examples) with Stochastic Gradient Descent Example. - Standardized nomenclature of {class}`mpol.coordinates.GridCoords` and {class}`mpol.fourier.FourierCube` to use `sky_cube` for a normal image and `ground_cube` for a normal visibility cube (rather than `sky_` for visibility quantities). Routines use `packed_cube` instead of `cube` internally to be clear when packed format is preferred. - Modified {class}`mpol.coordinates.GridCoords` object to use cached properties [#187](https://github.com/MPoL-dev/MPoL/pull/187). - Changed the base spatial frequency unit from k$\lambda$ to $\lambda$, addressing [#223](https://github.com/MPoL-dev/MPoL/issues/223). This will affect most users data-reading routines! diff --git a/docs/conf.py b/docs/conf.py index b9d2f9d7..9666fee1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,3 @@ -import os # -- Project information ----------------------------------------------------- from pkg_resources import DistributionNotFound, get_distribution @@ -46,7 +45,7 @@ autodoc_mock_imports = ["torch", "torchvision"] autodoc_member_order = "bysource" # https://github.com/sphinx-doc/sphinx/issues/9709 -# bug that if we set this here, we can't list individual members in the +# bug that if we set this here, we can't list individual members in the # actual API doc # autodoc_default_options = {"members": None} diff --git a/pyproject.toml b/pyproject.toml index 22146fc5..3f04b169 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ dev = [ "mypy", "frank>=1.2.1", "sphinx>=7.2.0", - "sphinx-autodoc2", "jupytext", "ipython!=8.7.0", # broken version for syntax higlight https://github.com/spatialaudio/nbsphinx/issues/687 "nbsphinx", @@ -51,7 +50,8 @@ dev = [ "asdf", "pyro-ppl", "arviz[all]", - "visread>=0.0.4" + "visread>=0.0.4", + "ruff" ] test = [ "pytest", @@ -62,6 +62,7 @@ test = [ "mypy", "visread>=0.0.4", "frank>=1.2.1", + "ruff" ] [project.urls] @@ -105,4 +106,18 @@ module = [ "MPoL.precomposed", "MPoL.utils" ] -disallow_untyped_defs = true \ No newline at end of file +disallow_untyped_defs = true + +[tool.ruff] +target-version = "py310" +line-length = 88 +# will enable after sorting module locations +# select = ["F", "I", "E", "W", "YTT", "B", "Q", "PLE", "PLR", "PLW", "UP"] +lint.ignore = [ + "E741", # Allow ambiguous variable names + "PLR0911", # Allow many return statements + "PLR0913", # Allow many arguments to functions + "PLR0915", # Allow many statements + "PLR2004", # Allow magic numbers in comparisons +] +exclude = [] \ No newline at end of file diff --git a/src/mpol/__init__.py b/src/mpol/__init__.py index 801c219f..5d74b25a 100644 --- a/src/mpol/__init__.py +++ b/src/mpol/__init__.py @@ -1,2 +1 @@ -from mpol.mpol_version import __version__ zenodo_record = 10064221 diff --git a/src/mpol/coordinates.py b/src/mpol/coordinates.py index 66a6c4ec..04b3af01 100644 --- a/src/mpol/coordinates.py +++ b/src/mpol/coordinates.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import cached_property +from functools import cached_property from typing import Any import numpy as np @@ -10,7 +10,7 @@ import mpol.constants as const from mpol.exceptions import CellSizeError -from mpol.utils import get_max_spatial_freq, get_maximum_cell_size +from mpol.utils import get_maximum_cell_size class GridCoords: @@ -79,6 +79,7 @@ class GridCoords: :ivar vis_ext: length-4 list of (left, right, bottom, top) expected by routines like ``matplotlib.pyplot.imshow`` in the ``extent`` parameter assuming ``origin='lower'``. Units of [:math:`\lambda`] + :ivar vis_ext_Mlam: like vis_ext, but in units of [:math:`\mathrm{M}\lambda`]. """ def __init__(self, cell_size: float, npix: int): @@ -205,16 +206,18 @@ def vis_ext(self) -> list[float]: self.u_bin_max, self.v_bin_min, self.v_bin_max, - ] # [kλ] + ] # [λ] + + @property + def vis_ext_Mlam(self) -> list[float]: + return [1e-6 * edge for edge in self.vis_ext] - # -------------------------------------------------------------------------- - # Non-identical u & v properties - # -------------------------------------------------------------------------- @cached_property def ground_u_centers_2D(self) -> npt.NDArray[np.floating[Any]]: # only useful for plotting a sky_vis # uu increasing, no fftshift - # tile replicates the 1D u_centers array to a 2D array the size of the full UV grid + # tile replicates the 1D u_centers array to a 2D array the size of the full + # UV grid return np.tile(self.u_centers, (self.npix_u, 1)) @cached_property @@ -304,10 +307,10 @@ def check_data_fit( Parameters ---------- - uu : :class:`torch.Tensor` of `torch.double` + uu : :class:`torch.Tensor` u spatial frequency coordinates. Units of [:math:`\lambda`] - vv : :class:`torch.Tensor` of `torch.double` + vv : :class:`torch.Tensor` v spatial frequency coordinates. Units of [:math:`\lambda`] @@ -354,6 +357,6 @@ def __eq__(self, other: Any) -> bool: # don't attempt to compare against different types return NotImplemented - # GridCoords objects are considered equal if they have the same cell_size and npix, since - # all other attributes are derived from these two core properties. + # GridCoords objects are considered equal if they have the same cell_size and + # npix, since all other attributes are derived from these two core properties. return bool(self.cell_size == other.cell_size and self.npix == other.npix) diff --git a/src/mpol/crossval.py b/src/mpol/crossval.py index fa02ab9c..e473e1bc 100644 --- a/src/mpol/crossval.py +++ b/src/mpol/crossval.py @@ -2,7 +2,6 @@ import copy import logging -from collections import defaultdict from typing import Any import numpy as np @@ -11,11 +10,9 @@ from numpy.typing import NDArray from mpol.datasets import Dartboard, GriddedDataset -from mpol.precomposed import GriddedNet + # from mpol.training import TrainTest, train_to_dirty_image # from mpol.training import TrainTest, train_to_dirty_image -from mpol.plot import split_diagnostics_fig -from mpol.utils import loglinspace # class CrossValidate: @@ -59,7 +56,8 @@ # Number of k-folds to use in cross-validation # split_method : str, default='dartboard' # Method to split full dataset into train/test subsets -# dartboard_q_edges, dartboard_phi_edges : list of float, default=None, unit=[klambda] +# dartboard_q_edges, dartboard_phi_edges : list of float, default=None, +# unit=[klambda] # Radial and azimuthal bin edges of the cells used to split the dataset # if `split_method`==`dartboard` (see `datasets.Dartboard`) # split_diag_fig : bool, default=False diff --git a/src/mpol/datasets.py b/src/mpol/datasets.py index 7444015e..b9864864 100644 --- a/src/mpol/datasets.py +++ b/src/mpol/datasets.py @@ -7,9 +7,8 @@ from numpy import floating, integer from numpy.typing import ArrayLike, NDArray -from mpol.coordinates import GridCoords - from mpol import utils +from mpol.coordinates import GridCoords class GriddedDataset(torch.nn.Module): @@ -20,7 +19,7 @@ class GriddedDataset(torch.nn.Module): If providing this, cannot provide ``cell_size`` or ``npix``. vis_gridded : :class:`torch.Tensor` of :class:`torch.complex128` the gridded visibility data stored in a "packed" format (pre-shifted for fft) - weight_gridded : :class:`torch.Tensor` of :class:`torch.double` + weight_gridded : :class:`torch.Tensor` the weights corresponding to the gridded visibility data, also in a packed format mask : :class:`torch.Tensor` of :class:`torch.bool` diff --git a/src/mpol/fourier.py b/src/mpol/fourier.py index c45c272b..82e63795 100644 --- a/src/mpol/fourier.py +++ b/src/mpol/fourier.py @@ -1,18 +1,14 @@ from __future__ import annotations -from typing import Any - import numpy as np import torch import torch.fft # to avoid conflicts with old torch.fft *function* import torchkbnufft -from numpy.typing import NDArray from torch import nn -from mpol.exceptions import DimensionMismatchError - from mpol import utils from mpol.coordinates import GridCoords +from mpol.exceptions import DimensionMismatchError class FourierCube(nn.Module): @@ -44,29 +40,29 @@ def __init__(self, coords: GridCoords, persistent_vis: bool = False): self.register_buffer("vis", None, persistent=persistent_vis) self.vis: torch.Tensor - def forward(self, cube: torch.Tensor) -> torch.Tensor: + def forward(self, packed_cube: torch.Tensor) -> torch.Tensor: """ Perform the FFT of the image cube on each channel. Parameters ---------- - cube : :class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)`` + cube : :class:`torch.Tensor` of shape ``(nchan, npix, npix)`` A 'packed' tensor. For example, an image cube from :meth:`mpol.images.ImageCube.forward` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)``. + :class:`torch.Tensor` of shape ``(nchan, npix, npix)``. The FFT of the image cube, in packed format. """ # make sure the cube is 3D - assert cube.dim() == 3, "cube must be 3D" + assert packed_cube.dim() == 3, "cube must be 3D" # the self.cell_size prefactor (in arcsec) is to obtain the correct output units # since it needs to correct for the spacing of the input grid. # See MPoL documentation and/or TMS Eqn A8.18 for more information. - self.vis = self.coords.cell_size**2 * torch.fft.fftn(cube, dim=(1, 2)) + self.vis = self.coords.cell_size**2 * torch.fft.fftn(packed_cube, dim=(1, 2)) return self.vis @@ -93,7 +89,7 @@ def ground_amp(self) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)`` + :class:`torch.Tensor` of shape ``(nchan, npix, npix)`` amplitude cube in 'ground' format. """ return torch.abs(self.ground_vis) @@ -106,7 +102,7 @@ def ground_phase(self) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` of shape ``(nchan, npix, npix)`` + :class:`torch.Tensor` of shape ``(nchan, npix, npix)`` phase cube in 'ground' format (:math:`[-\pi,\pi)`). """ return torch.angle(self.ground_vis) @@ -279,14 +275,14 @@ def forward( Parameters ---------- - packed_cube : :class:`torch.Tensor` of :class:`torch.double` + packed_cube : :class:`torch.Tensor` shape ``(nchan, npix, npix)``). The cube should be a "prepacked" image cube, for example, from :meth:`mpol.images.ImageCube.forward` - uu : :class:`torch.Tensor` of :class:`torch.double` + uu : :class:`torch.Tensor` 2D array of the u (East-West) spatial frequency coordinate [:math:`\lambda`] of shape ``(nchan, npix)`` - vv : :class:`torch.Tensor` of :class:`torch.double` + vv : :class:`torch.Tensor` 2D array of the v (North-South) spatial frequency coordinate [:math:`\lambda`] (must be the same shape as uu) sparse_matrices : bool @@ -372,7 +368,7 @@ def forward( shifted = torch.fft.fftshift(packed_cube, dim=(1, 2)) # convert the cube to a complex type, since this is required by TorchKbNufft - complexed = shifted.type(torch.complex128) + complexed = shifted + 0j k_traj = self._assemble_ktraj(uu, vv) @@ -439,10 +435,10 @@ def forward( class NuFFTCached(NuFFT): r""" - This layer is similar to the :class:`mpol.fourier.NuFFT`, but provides extra + This layer is similar to the :class:`mpol.fourier.NuFFT`, but provides extra functionality to cache the sparse matrices for a specific set of ``uu`` and ``vv`` - points specified at initialization. - + points specified at initialization. + For repeated evaluations of this layer (as might exist within an optimization loop), ``sparse_matrices=True`` is likely to be the more accurate and faster choice. If ``sparse_matrices=False``, this routine will use the default table-based @@ -502,7 +498,7 @@ def __init__( self.real_interp_mat: torch.Tensor self.imag_interp_mat: torch.Tensor - def forward(self, cube): + def forward(self, packed_cube): r""" Perform the FFT of the image cube for each channel and interpolate to the ``uu`` and ``vv`` points set at layer initialization. This call should @@ -510,9 +506,10 @@ def forward(self, cube): ``uu`` and ``vv`` points. Args: - cube (torch.double tensor): of shape ``(nchan, npix, npix)``). The cube - should be a "prepacked" image cube, for example, from - :meth:`mpol.images.ImageCube.forward` + packed_cube : :class:`torch.Tensor` + shape ``(nchan, npix, npix)``). The cube + should be a "prepacked" image cube, for example, + from :meth:`mpol.images.ImageCube.forward` Returns: torch.complex tensor: of shape ``(nchan, nvis)``, Fourier samples evaluated @@ -521,17 +518,18 @@ def forward(self, cube): # make sure that the nchan assumptions for the ImageCube and the NuFFT # setup are the same - if cube.size(0) != self.nchan: + if packed_cube.size(0) != self.nchan: raise DimensionMismatchError( - f"nchan of ImageCube ({cube.size(0)}) is different than that used to initialize NuFFT layer ({self.nchan})" + f"nchan of ImageCube ({packed_cube.size(0)}) is different than that used to initialize NuFFT layer ({self.nchan})" ) # "unpack" the cube, but leave it flipped # NuFFT routine expects a "normal" cube, not an fftshifted one - shifted = torch.fft.fftshift(cube, dim=(1, 2)) + shifted = torch.fft.fftshift(packed_cube, dim=(1, 2)) # convert the cube to a complex type, since this is required by TorchKbNufft - complexed = shifted.type(torch.complex128) + complexed = shifted + 0j + # Consider how the similarity of the spatial frequency samples should be # treated. We already took care of this on the k_traj side, since we set diff --git a/src/mpol/geometry.py b/src/mpol/geometry.py index 27241c63..98d41013 100644 --- a/src/mpol/geometry.py +++ b/src/mpol/geometry.py @@ -16,8 +16,8 @@ def flat_to_observer( (X,Y,Z). It is assumed that the +Z axis points *towards* the observer. It is assumed that the - model is flat in the (x,y) frame (like a flat disk), and so the operations - involving ``z`` are neglected. But the model lives in 3D Cartesian space. + model is flat in the (x,y) frame (like a flat disk), and so the operations + involving ``z`` are neglected. But the model lives in 3D Cartesian space. In order, @@ -28,11 +28,18 @@ def flat_to_observer( Inspired by `exoplanet/keplerian.py `_ + Note that the (x,y) here are *not* the same as the `x_centers_2D` or `y_centers_2D` + attached to the :class:`mpol.coordinates.GridCoords` object. The (x,y) referred to + here are the 'perifocal frame' of the orbit, whereas the (X,Y,Z) are the sky or + observer frame. Typically, the sky observer frame is oriented such that X is North + (pointing up) and Y is East (pointing left). For more detail, see the `exoplanet + docs `_ or `Murray and Correia `_. + Parameters ---------- - x : :class:`torch.Tensor` of :class:`torch.double` + x : :class:`torch.Tensor` A tensor representing the x coordinate in the plane of the orbit. - y : :class:`torch.Tensor` of :class:`torch.double` + y : :class:`torch.Tensor` A tensor representing the y coordinate in the plane of the orbit. omega : float Argument of periastron [radians]. Default 0.0. @@ -43,7 +50,7 @@ def flat_to_observer( Returns ------- - 2-tuple of :class:`torch.Tensor` of :class:`torch.double` + 2-tuple of :class:`torch.Tensor` Two tensors representing ``(X, Y)`` in the observer frame. """ # Rotation matrices result in a *clockwise* rotation of the axes, @@ -98,12 +105,19 @@ def observer_to_flat( Inspired by `exoplanet/keplerian.py `_ + Note that the (x,y) here are *not* the same as the `x_centers_2D` or `y_centers_2D` + attached to the :class:`mpol.coordinates.GridCoords` object. The (x,y) referred to + here are the 'perifocal frame' of the orbit, whereas the (X,Y,Z) are the sky or + observer frame. Typically, the sky observer frame is oriented such that X is North + (pointing up) and Y is East (pointing left). For more detail, see the `exoplanet + docs `_ or `Murray and Correia `_. + Parameters ---------- - X : :class:`torch.Tensor` of :class:`torch.double` - A tensor representing the x coordinate in the plane of the orbit. - Y : :class:`torch.Tensor` of :class:`torch.double` - A tensor representing the y coordinate in the plane of the orbit. + X : :class:`torch.Tensor` + A tensor representing the x coordinate in the plane of the sky. + Y : :class:`torch.Tensor` + A tensor representing the y coordinate in the plane of the sky. omega : float A tensor representing an argument of periastron [radians] Default 0.0. incl : float @@ -114,7 +128,7 @@ def observer_to_flat( Returns ------- - 2-tuple of :class:`torch.Tensor` of :class:`torch.double` + 2-tuple of :class:`torch.Tensor` Two tensors representing ``(x, y)`` in the flat frame. """ # Rotation matrices result in a *clockwise* rotation of the axes, @@ -136,7 +150,7 @@ def observer_to_flat( # we don't know Z, but we can solve some equations to find that # y = Y / cos(i), as expected by intuition y1 = y2 / np.cos(incl) - + # 3) inverse rotation about the z1 axis by an amount of omega cos_omega = np.cos(omega) sin_omega = np.sin(omega) diff --git a/src/mpol/gridding.py b/src/mpol/gridding.py index 0991aa04..bd317b45 100644 --- a/src/mpol/gridding.py +++ b/src/mpol/gridding.py @@ -1,19 +1,18 @@ from __future__ import annotations import warnings - -from typing import Any, Callable, Sequence +from collections.abc import Callable, Sequence +from typing import Any import numpy as np import numpy.typing as npt -from fast_histogram import histogram as fast_hist - import torch +from fast_histogram import histogram as fast_hist +from mpol import utils from mpol.coordinates import GridCoords -from mpol.exceptions import DataError, ThresholdExceededError, WrongDimensionError from mpol.datasets import GriddedDataset -from mpol import utils +from mpol.exceptions import DataError, ThresholdExceededError, WrongDimensionError def _check_data_inputs_2d( diff --git a/src/mpol/images.py b/src/mpol/images.py index ed120970..9e6b1310 100644 --- a/src/mpol/images.py +++ b/src/mpol/images.py @@ -1,16 +1,20 @@ -r"""The ``images`` module provides the core functionality of MPoL via +r"""The ``images`` module provides the core functionality of MPoL via :class:`mpol.images.ImageCube`.""" from __future__ import annotations +from collections.abc import Callable + import numpy as np +from typing import Any +import numpy.typing as npt + import torch import torch.fft # to avoid conflicts with old torch.fft *function* from torch import nn +import math -from typing import Any, Callable - -from mpol import utils +from mpol import constants, utils from mpol.coordinates import GridCoords @@ -60,11 +64,13 @@ def __init__( # The ``base_cube`` is already packed to make the Fourier transformation easier if base_cube is None: + # base_cube = -3 yields a nearly-blank cube after softplus, whereas + # base_cube = 0.0 yields a cube with avg value of ~0.7, which is too high self.base_cube = nn.Parameter( - torch.zeros( + -3 + * torch.ones( (self.nchan, self.coords.npix, self.coords.npix), requires_grad=True, - dtype=torch.double, ) ) @@ -111,14 +117,14 @@ class HannConvCube(nn.Module): \end{bmatrix} which is the 2D version of the discretely-sampled response function corresponding to - a Hann window, i.e., it is two 1D Hann windows multiplied together. This is a - convolutional kernel in the image plane, and so effectively acts as apodization - by a Hann window function in the Fourier domain. For more information, see the - following Wikipedia articles on `Window Functions - `_ in general and the `Hann Window + a Hann window, i.e., it is two 1D Hann windows multiplied together. This is a + convolutional kernel in the image plane, and so effectively acts as apodization + by a Hann window function in the Fourier domain. For more information, see the + following Wikipedia articles on `Window Functions + `_ in general and the `Hann Window `_ specifically. - The idea is that this layer would help naturally attenuate high spatial frequency + The idea is that this layer would help naturally attenuate high spatial frequency artifacts by baking in a natural apodization in the Fourier plane. Args: @@ -143,7 +149,7 @@ def __init__(self, nchan: int, requires_grad: bool = False) -> None: # bias has shape (nchan) # build out the discretely-sampled Hann kernel - spec = torch.tensor([0.25, 0.5, 0.25], dtype=torch.double) + spec = torch.tensor([0.25, 0.5, 0.25]) nugget = torch.outer(spec, spec) # shape (3,3) 2D Hann kernel exp = torch.unsqueeze(torch.unsqueeze(nugget, 0), 0) # shape (1, 1, 3, 3) weight = exp.repeat(nchan, 1, 1, 1) # shape (nchan, 1, 3, 3) @@ -154,9 +160,7 @@ def __init__(self, nchan: int, requires_grad: bool = False) -> None: ) # set the (untunable) weight # set the bias to zero - self.m.bias = nn.Parameter( - torch.zeros(nchan, dtype=torch.double), requires_grad=requires_grad - ) + self.m.bias = nn.Parameter(torch.zeros(nchan), requires_grad=requires_grad) def forward(self, cube: torch.Tensor) -> torch.Tensor: r"""Args: @@ -187,6 +191,226 @@ def forward(self, cube: torch.Tensor) -> torch.Tensor: return utils.sky_cube_to_packed_cube(conv_sky_cube) +class GaussConvImage(nn.Module): + r""" + This convolutional layer will convolve the input cube with a 2D Gaussian kernel. + The filter is the same for all channels in the input cube. + Because the operation is carried out in the image domain, note that it may become + computationally prohibitive for large kernel sizes. In that case, + :class:`mpol.images.GaussConvFourier` may be preferred. + + Parameters + ---------- + coords : :class:`mpol.coordinates.GridCoords` + an object instantiated from the GridCoords class, containing information about + the image `cell_size` and `npix`. + nchan : int + the number of channels in the base cube. Default = 1. + FWHM_maj: float, units of arcsec + the FWHH of the Gaussian along the major axis + FWHM_min: float, units of arcsec + the FWHM of the Gaussian along the minor axis + Omega: float, degrees + the rotation of the major axis of the PSF, in degrees East of North. 0 degrees rotation has the major axis aligned in the North-South direction. + requires_grad : bool + Should the kernel parameters be trainable? Most applications will want to use + `False`. + """ + + def __init__( + self, + coords: GridCoords, + nchan: int, + FWHM_maj: float, + FWHM_min: float, + Omega: float = 0, + requires_grad: bool = False, + ) -> None: + super().__init__() + + # convert FWHM to sigma and to radians + FWHM2sigma = 1 / (2 * np.sqrt(2 * np.log(2))) + + # In this routine, x, y are used in the same sense as the GridCoords + # object uses 'sky_x' and 'sky_y', i.e. x is l in arcseconds and + # y is m in arcseconds. + + # assumes major axis along m direction at 0 degrees rotation. + sigma_y = FWHM_maj * FWHM2sigma # arcsec + sigma_x = FWHM_min * FWHM2sigma # arcsec + + # calculate filter out to some Gaussian width, and make a kernel with an + # odd number of pixels + limit = 3.0 * sigma_y + npix_kernel = 1 + 2 * math.ceil(limit / coords.cell_size) + + if npix_kernel < 3: + raise RuntimeError( + """FWHM_maj is so small ({:} arcsec) relative to the + cell_size ({:} arcsec) that the convolutional kernel would only be + one pixel wide. Increase FWHM_maj or remove this + convolutional layer entirely""".format( + npix_kernel, coords.cell_size + ) + ) + + # create a grid to evaluate the 2D Gaussian, using an even number of + # pixels with the kernel centered (no max pixel) + kernel_centers = np.linspace(-limit, limit, num=npix_kernel) # [arcsec] + + # borrowed from GridCoords logic + x_centers_2D = np.tile(kernel_centers, (npix_kernel, 1)) # [arcsec] + sky_x_centers_2D = np.fliplr(x_centers_2D) + + sky_y_centers_2D = np.tile(kernel_centers, (npix_kernel, 1)).T # [arcsec] + + # evaluate Gaussian over grid + gauss = utils.sky_gaussian_arcsec( + sky_x_centers_2D, + sky_y_centers_2D, + 1.0, + delta_x=0.0, + delta_y=0.0, + sigma_x=sigma_x, + sigma_y=sigma_y, + Omega=Omega, + ) + # normalize kernel to keep total flux the same + gauss /= np.sum(gauss) + nugget = torch.tensor(gauss, dtype=torch.float32) + exp = torch.unsqueeze( + torch.unsqueeze(nugget, 0), 0 + ) # shape (1, 1, npix_kernel, npix_kernel) + weight = exp.repeat( + nchan, 1, 1, 1 + ) # shape (nchan, 1, npix_kernel, npix_kernel) + + # groups = nchan will give us the minimal set of filters we need + # somewhat confusingly, the neural network literature calls this + # a "depthwise" convolution. I think that "depthwise" is not meant to imply + # that there is now a consideration of the depth (e.g., color channel) + # dimension when before there wasn't. + # Rather, the emphasis is on the *wise*, as in "pairwise," in that + # each depth channel is treated individually with its own filter, rather than + # a filter that draws from multiple depth channels at once. + # I think a better name is "channel-separate" convolution as indicated in the + # "Understanding Deep Learning" textbook by Prince in Ch 10.6. + + # simple convolutional filter operates on per-channel basis + self.m = nn.Conv2d( + in_channels=nchan, + out_channels=nchan, + kernel_size=npix_kernel, + stride=1, + groups=nchan, + padding="same", + ) + + # weights has size (nchan, 1, npix_kernel, npix_kernel) + # bias has shape (nchan) + + # set the weight and bias + self.m.weight = nn.Parameter( + weight, requires_grad=requires_grad + ) # set the (untunable) weight + + # set the bias to zero + self.m.bias = nn.Parameter(torch.zeros(nchan), requires_grad=requires_grad) + + def forward(self, sky_cube: torch.Tensor) -> torch.Tensor: + r"""Args: + sky_cube (torch.double tensor, of shape ``(nchan, npix, npix)``): an image cube in sky format (note, not packed). + + Returns: + torch.complex tensor: the FFT of the image cube, in sky format and of + shape ``(nchan, npix, npix)`` + """ + convolved_sky: torch.Tensor + convolved_sky = self.m(sky_cube) + return convolved_sky + +class GaussConvFourier(nn.Module): + r""" + This layer will convolve the input cube with a (potentially non-circular) Gaussian + beam, using a Fourier strategy. + The size of the beam is set upon initialization of the layer. + + Parameters + ---------- + coords : :class:`mpol.coordinates.GridCoords` + an object instantiated from the GridCoords class, containing information about + the image `cell_size` and `npix`. + FWHM_maj: float, units of arcsec + the FWHH of the Gaussian along the major axis + FWHM_min: float, units of arcsec + the FWHM of the Gaussian along the minor axis + Omega: float, degrees + the rotation of the major axis of the PSF, in degrees East of North. 0 degrees rotation has the major axis aligned in the North-South direction. + """ + + def __init__( + self, + coords: GridCoords, + FWHM_maj: float, + FWHM_min: float, + Omega: float = 0) -> None: + super().__init__() + + self.coords = coords + self.FWHM_maj = FWHM_maj + self.FWHM_min = FWHM_min + self.Omega = Omega + + taper_2D = uv_gaussian_taper(self.coords, self.FWHM_maj, self.FWHM_min, self.Omega) + + # store taper to register so it transfers to GPU + self.register_buffer("taper_2D", torch.tensor(taper_2D, dtype=torch.float32)) + + def forward(self, packed_cube): + r""" + Convolve a packed_cube image with a 2D Gaussian PSF. Operation is carried out + in the Fourier domain using a Gaussian taper. + + Parameters + ---------- + packed_cube : :class:`torch.Tensor` type + shape ``(nchan, npix, npix)`` image cube in packed format. + + Returns + ------- + :class:`torch.Tensor` + The convolved cube in packed format. + """ + nchan, npix_m, npix_l = packed_cube.size() + assert ( + (npix_m == self.coords.npix) and (npix_l == self.coords.npix) + ), "packed_cube {:} does not have the same pixel dimensions as indicated by coords {:}".format( + packed_cube.size(), self.coords.npix + ) + + # in FFT packed format + # we're round-tripping, so we can ignore prefactors for correctness + # calling this `vis_like`, since it's not actually the vis + vis_like = torch.fft.fftn(packed_cube, dim=(1, 2)) + + # apply taper to packed image + tapered_vis = vis_like * torch.broadcast_to(self.taper_2D, packed_cube.size()) + + # iFFT back, ignoring prefactors for round-trip + convolved_packed_cube = torch.fft.ifftn(tapered_vis, dim=(1, 2)) + + # assert imaginaries are effectively zero, otherwise something went wrong + thresh = 1e-7 + assert ( + torch.max(convolved_packed_cube.imag) < thresh + ), "Round-tripped image contains max imaginary value {:} > {:} threshold, something may be amiss.".format( + torch.max(convolved_packed_cube.imag), thresh + ) + + r_cube: torch.Tensor = convolved_packed_cube.real + return r_cube + + class ImageCube(nn.Module): r""" The parameter set is the pixel values of the image cube itself. The pixels are @@ -225,18 +449,18 @@ def __init__( def forward(self, packed_cube: torch.Tensor) -> torch.Tensor: r""" - Pass the cube through as an identity operation, storing the value to the - internal buffer. After the cube has been passed through, convenience + Pass the cube through as an identity operation, storing the value to the + internal buffer. After the cube has been passed through, convenience instance attributes like `sky_cube` and `flux` will reflect the updated cube. Parameters ---------- - packed_cube : :class:`torch.Tensor` of type :class:`torch.double` + packed_cube : :class:`torch.Tensor` of type :class:`torch.double` 3D torch tensor of shape ``(nchan, npix, npix)``) in 'packed' format Returns ------- - :class:`torch.Tensor` of :class:`torch.double` type + :class:`torch.Tensor` tensor of shape ``(nchan, npix, npix)``), same as `cube` """ self.packed_cube = packed_cube @@ -285,10 +509,10 @@ def to_FITS( Returns: None """ - + from astropy import wcs from astropy.io import fits - + w = wcs.WCS(naxis=2) w.wcs.crpix = np.array([1, 1]) @@ -310,3 +534,51 @@ def to_FITS( hdul.writeto(fname, overwrite=overwrite) hdul.close() + + +def uv_gaussian_taper( + coords: GridCoords, FWHM_maj: float, FWHM_min: float, Omega: float +) -> npt.NDArray[np.floating[Any]]: + r""" + Compute a packed Gaussian taper in the Fourier domain, to multiply against a packed + visibility cube. While similar to :meth:`mpol.utils.fourier_gaussian_lambda_arcsec`, + this routine delivers a visibility-plane taper with maximum amplitude normalized + to 1.0. + + Parameters + ---------- + coords: :class:`mpol.coordinates.GridCoords` + object indicating image and Fourier grid specifications. + FWHM_maj: float, units of arcsec + the FWHH of the Gaussian along the major axis + FWHM_min: float, units of arcsec + the FWHM of the Gaussian along the minor axis + Omega: float, degrees + the rotation of the major axis of the PSF, in degrees East of North. 0 degrees rotation has the major axis aligned in the East-West direction. + + Returns + ------- + :class:`np.ndarray` , shape ``(npix, npix)`` + The Gaussian taper in packed format. + """ + + # convert FWHM to sigma and to radians + FWHM2sigma = 1 / (2 * np.sqrt(2 * np.log(2))) + sigma_l = FWHM_maj * FWHM2sigma * constants.arcsec # radians + sigma_m = FWHM_min * FWHM2sigma * constants.arcsec # radians + + u = coords.packed_u_centers_2D + v = coords.packed_v_centers_2D + + # calculate primed rotated coordinates + Omega_d = Omega * constants.deg + up = u * np.cos(Omega_d) - v * np.sin(Omega_d) + vp = u * np.sin(Omega_d) + v * np.cos(Omega_d) + + # calculate the Fourier Gaussian + taper_2D: npt.NDArray[np.floating[Any]] + taper_2D = np.exp(-2 * np.pi**2 * (sigma_l**2 * up**2 + sigma_m**2 * vp**2)) + + # a flux-conserving taper must have an amplitude of 1 at the origin. + + return taper_2D \ No newline at end of file diff --git a/src/mpol/input_output.py b/src/mpol/input_output.py index 3ff1be3e..1655773c 100644 --- a/src/mpol/input_output.py +++ b/src/mpol/input_output.py @@ -1,7 +1,7 @@ -import numpy as np - +import numpy as np from astropy.io import fits + class ProcessFitsImage: """ Utilities for loading and retrieving metrics of a .fits image @@ -21,7 +21,7 @@ def __init__(self, filename, channel=0): def get_extent(self, header): """Get extent (in RA and Dec, units of [arcsec]) of image""" - + # get the coordinate labels nx = header["NAXIS1"] ny = header["NAXIS2"] @@ -55,7 +55,7 @@ def get_extent(self, header): def get_beam(self, hdu_list, header): - """Get the major and minor widths [arcsec], and position angle, of a + """Get the major and minor widths [arcsec], and position angle, of a clean beam""" if header.get("CASAMBM") is not None: @@ -74,12 +74,12 @@ def get_beam(self, hdu_list, header): def get_image(self, beam=True): - """Load a .fits image and return as a numpy array. Also return image + """Load a .fits image and return as a numpy array. Also return image extent and optionally (`beam`) the clean beam dimensions""" hdu_list = fits.open(self._fits_file) hdu = hdu_list[0] - + if len(hdu.data.shape) in [3, 4]: image = hdu.data[self._channel] # first channel else: @@ -98,4 +98,3 @@ def get_image(self, beam=True): return image, ext, self.get_beam(hdu_list, header) else: return image, ext - \ No newline at end of file diff --git a/src/mpol/losses.py b/src/mpol/losses.py index ed0cfba7..00e4c1e6 100644 --- a/src/mpol/losses.py +++ b/src/mpol/losses.py @@ -1,9 +1,10 @@ +from typing import Optional + import numpy as np import torch from mpol import constants from mpol.datasets import GriddedDataset -from typing import Optional def _chi_squared( @@ -29,12 +30,12 @@ def _chi_squared( array of the model values representing :math:`\boldsymbol{V}` data_vis : :class:`torch.Tensor` of :class:`torch.complex` array of the data values representing :math:`M` - weight : :class:`torch.Tensor` of :class:`torch.double` + weight : :class:`torch.Tensor` array of weight values representing :math:`w_i` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the :math:`\chi^2` likelihood, summed over all dimensions of input array. """ @@ -78,12 +79,12 @@ def r_chi_squared( array of the model values representing :math:`\boldsymbol{V}` data_vis : :class:`torch.Tensor` of :class:`torch.complex` array of the data values representing :math:`M` - weight : :class:`torch.Tensor` of :class:`torch.double` + weight : :class:`torch.Tensor` array of weight values representing :math:`w_i` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the :math:`\chi^2_\mathrm{R}`, summed over all dimensions of input array. """ @@ -115,7 +116,7 @@ def r_chi_squared_gridded( Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the :math:`\chi^2_\mathrm{R}` value summed over all input dimensions """ model_vis = griddedDataset(modelVisibilityCube) @@ -165,12 +166,12 @@ def log_likelihood( array of the model values representing :math:`\boldsymbol{V}` data_vis : :class:`torch.Tensor` of :class:`torch.complex128` array of the data values representing :math:`M` - weight : :class:`torch.Tensor` of :class:`torch.double` + weight : :class:`torch.Tensor` array of weight values representing :math:`w_i` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the :math:`\ln\mathcal{L}` log likelihood, summed over all dimensions of input array. """ @@ -207,7 +208,7 @@ def log_likelihood_gridded( Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the :math:`\ln\mathcal{L}` value, summed over all dimensions of input data. """ @@ -246,12 +247,12 @@ def neg_log_likelihood_avg( array of the model values representing :math:`\boldsymbol{V}` data_vis : :class:`torch.Tensor` of :class:`torch.complex` array of the data values representing :math:`M` - weight : :class:`torch.Tensor` of :class:`torch.double` + weight : :class:`torch.Tensor` array of weight values representing :math:`w_i` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the average of the negative log likelihood, summed over all dimensions of input array. """ @@ -274,9 +275,9 @@ def entropy( Parameters ---------- - cube : :class:`torch.Tensor` of :class:`torch.double` + cube : :class:`torch.Tensor` pixel values must be positive :math:`I_i > 0` for all :math:`i` - prior_intensity : :class:`torch.Tensor` of :class:`torch.double` + prior_intensity : :class:`torch.Tensor` the prior value :math:`p` to calculate entropy against. Tensors of any shape are allowed so long as they will broadcast to the shape of the cube under division (`/`). @@ -286,7 +287,7 @@ def entropy( Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` entropy loss """ # check to make sure image is positive, otherwise raise an error @@ -312,7 +313,7 @@ def TV_image(sky_cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor: Parameters ---------- - sky_cube: 3D :class:`torch.Tensor` of :class:`torch.double` + sky_cube: 3D :class:`torch.Tensor` the image cube array :math:`I_{lmv}`, where :math:`l` is R.A. in :math:`ndim=3`, :math:`m` is DEC in :math:`ndim=2`, and :math:`v` is the channel (velocity or frequency) dimension in @@ -324,7 +325,7 @@ def TV_image(sky_cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` total variation loss """ @@ -347,7 +348,7 @@ def TV_channel(cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor: Parameters ---------- - cube: :class:`torch.Tensor` of :class:`torch.double` + cube: :class:`torch.Tensor` the image cube array :math:`I_{lmv}` epsilon: float a softening parameter in units of [:math:`\mathrm{Jy}/\mathrm{arcsec}^2`]. @@ -356,7 +357,7 @@ def TV_channel(cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` total variation loss """ # calculate the difference between the n+1 cube and the n cube @@ -382,7 +383,7 @@ def TSV(sky_cube: torch.Tensor) -> torch.Tensor: Parameters ---------- - sky_cube :class:`torch.Tensor` of :class:`torch.double` + sky_cube :class:`torch.Tensor` the image cube array :math:`I_{lmv}`, where :math:`l` is R.A. in :math:`ndim=3`, :math:`m` is DEC in :math:`ndim=2`, and :math:`v` is the channel (velocity or frequency) dimension in @@ -390,7 +391,7 @@ def TSV(sky_cube: torch.Tensor) -> torch.Tensor: Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` total square variation loss """ @@ -417,7 +418,7 @@ def sparsity(cube: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.T Parameters ---------- - cube : :class:`torch.Tensor` of :class:`torch.double` + cube : :class:`torch.Tensor` the image cube array :math:`I_{lmv}` mask : :class:`torch.Tensor` of :class:`torch.bool` tensor array the same shape as ``cube``. The sparsity prior @@ -426,7 +427,7 @@ def sparsity(cube: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.T Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` sparsity loss calculated where ``mask == True`` """ @@ -457,7 +458,7 @@ def UV_sparsity( Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` UV sparsity loss above :math:`q_\mathrm{max}` """ @@ -497,16 +498,16 @@ def PSD(qs: torch.Tensor, psd: torch.Tensor, l: torch.Tensor) -> torch.Tensor: Parameters ---------- - qs : :class:`torch.Tensor` of :class:`torch.double` + qs : :class:`torch.Tensor` the radial UV coordinate (in :math:`\lambda`) - psd : :class:`torch.Tensor` of :class:`torch.double` + psd : :class:`torch.Tensor` the power spectral density cube - l : :class:`torch.Tensor` of :class:`torch.double` + l : :class:`torch.Tensor` the correlation length in the image plane (in arcsec) Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` the loss calculated using the power spectral density """ @@ -530,12 +531,12 @@ def edge_clamp(cube: torch.Tensor) -> torch.Tensor: Parameters ---------- - cube: :class:`torch.Tensor` of :class:`torch.double` + cube: :class:`torch.Tensor` the image cube array :math:`I_{lmv}` Returns ------- - :class:`torch.Tensor` of :class:`torch.double` + :class:`torch.Tensor` edge loss """ diff --git a/src/mpol/onedim.py b/src/mpol/onedim.py index b9d2a0f9..408edae8 100644 --- a/src/mpol/onedim.py +++ b/src/mpol/onedim.py @@ -1,4 +1,5 @@ import numpy as np + from mpol.utils import torch2npy diff --git a/src/mpol/plot.py b/src/mpol/plot.py index 9d48da41..afdf0872 100644 --- a/src/mpol/plot.py +++ b/src/mpol/plot.py @@ -1,15 +1,10 @@ -import numpy as np -import matplotlib.pyplot as plt import matplotlib.colors as mco -from matplotlib.patches import Ellipse -import torch - +import matplotlib.pyplot as plt +import numpy as np from astropy.visualization.mpl_normalize import simple_norm -from mpol.gridding import DirtyImager from mpol.onedim import radialI, radialV -from mpol.utils import loglinspace, torch2npy, packed_cube_to_sky_cube -from mpol.input_output import ProcessFitsImage +from mpol.utils import loglinspace, packed_cube_to_sky_cube, torch2npy def get_image_cmap_norm( @@ -265,9 +260,9 @@ def vis_histogram_fig( else: supported_q = ["count", "weight", "vis_real", "vis_imag"] raise ValueError( - "`bin_quantity` ({}) must be one of " - "{}, or a user-provided numpy " - " array".format(bin_quantity, supported_q) + f"`bin_quantity` ({bin_quantity}) must be one of " + f"{supported_q}, or a user-provided numpy " + " array" ) # buffer to include longest baselines in last bin @@ -279,7 +274,7 @@ def vis_histogram_fig( bin_lab = None if all(np.diff(q_edges1d) == np.diff(q_edges1d)[0]): - bin_lab = r"Bin size {:.0f} k$\lambda$".format(np.diff(q_edges1d)[0]) + bin_lab = rf"Bin size {np.diff(q_edges1d)[0]:.0f} k$\lambda$" # 2d histogram bins if q_edges is None: @@ -891,7 +886,7 @@ def vis_1d_fig( # deproject the (u,v) points uu, vv, _ = deproject(uu, vv, geom["incl"], geom["Omega"]) - + # if the source is optically thick, rescale the deprojected V(q) if rescale_flux: V.real /= np.cos(geom["incl"] * np.pi / 180) @@ -1008,7 +1003,7 @@ def radial_fig( channel=0, save_prefix=None, ): - """ + r""" Figure for analysis of 1D (radial) brightness profile of MPoL model image, using a user-supplied geometry. @@ -1083,7 +1078,7 @@ def radial_fig( # deproject the observed (u,v) points u, v, _ = deproject(u, v, geom["incl"], geom["Omega"]) - + # if the source is optically thick, rescale the deprojected V(q) if rescale_flux: V.real /= np.cos(geom["incl"] * np.pi / 180) diff --git a/src/mpol/precomposed.py b/src/mpol/precomposed.py index 8b07b281..24fab40a 100644 --- a/src/mpol/precomposed.py +++ b/src/mpol/precomposed.py @@ -1,12 +1,10 @@ +import typing + import torch +from mpol import fourier, images from mpol.coordinates import GridCoords -from mpol import fourier -from mpol import images - -import typing - class GriddedNet(torch.nn.Module): r""" @@ -36,7 +34,7 @@ class GriddedNet(torch.nn.Module): a pre-packed base cube to initialize the model with. If None, assumes ``torch.zeros``. - + After the object is initialized, instance variables can be accessed, for example :ivar bcube: the :class:`~mpol.images.BaseCube` instance @@ -88,7 +86,7 @@ def forward(self) -> torch.Tensor: def predict_loose_visibilities( self, uu: torch.Tensor, vv: torch.Tensor ) -> torch.Tensor: - """ + r""" Use the :class:`mpol.fourier.NuFFT` to calculate loose model visibilities from the cube stored to ``self.icube.packed_cube``. diff --git a/src/mpol/training.py b/src/mpol/training.py index cc75bc8f..2bed8999 100644 --- a/src/mpol/training.py +++ b/src/mpol/training.py @@ -1,10 +1,6 @@ import logging -import numpy as np -import torch -from mpol.losses import TSV, TV_image, entropy, r_chi_squared_gridded, sparsity -from mpol.plot import train_diagnostics_fig -from mpol.utils import torch2npy +import torch def train_to_dirty_image(model, imager, robust=0.5, learn_rate=100, niter=1000): diff --git a/src/mpol/utils.py b/src/mpol/utils.py index b7912f56..9173ff3b 100644 --- a/src/mpol/utils.py +++ b/src/mpol/utils.py @@ -1,10 +1,10 @@ import math +from typing import Any import numpy as np +import numpy.typing as npt import torch -from typing import Any -import numpy.typing as npt from mpol.constants import arcsec, cc, deg, kB @@ -221,10 +221,10 @@ def check_baselines(q, min_feasible_q=1e3, max_feasible_q=1e8): if max(q) > max_feasible_q: raise Warning( - "Maximum baseline of {:.1e} is > maximum expected " - "value of {:.1e}. Baselines must be in units of " + f"Maximum baseline of {max(q):.1e} is > maximum expected " + f"value of {max_feasible_q:.1e}. Baselines must be in units of " "[klambda], but it looks like they're in " - "[lambda].".format(max(q), max_feasible_q) + "[lambda]." ) if min(q) > min_feasible_q * 1e3: @@ -291,7 +291,7 @@ def get_optimal_image_properties( image_width : float, unit = arcsec Desired width of the image (for a square image of size `image_width` :math:`\times` `image_width`). - u, v : :class:`torch.Tensor` of :class:`torch.double`, unit = :math:`\lambda` + u, v : :class:`torch.Tensor` , unit = :math:`\lambda` `u` and `v` baselines. Returns @@ -342,8 +342,8 @@ def sky_gaussian_radians( Omega: float, ) -> npt.NDArray[np.floating[Any]]: r""" - Calculates a 2D Gaussian on the sky plane with inputs in radians. The Gaussian is - centered at ``delta_l, delta_m``, has widths of ``sigma_l, sigma_m``, and is + Calculates a 2D Gaussian on the sky plane with inputs in radians. The Gaussian is + centered at ``delta_l, delta_m``, has widths of ``sigma_l, sigma_m``, and is rotated ``Omega`` degrees East of North. To evaluate the Gaussian, internally first we translate to center @@ -364,8 +364,8 @@ def sky_gaussian_radians( .. math:: - f_\mathrm{g}(l,m) = a \exp \left ( - \frac{1}{2} \left - [ \left (\frac{l''}{\sigma_l} \right)^2 + \left( \frac{m''}{\sigma_m} + f_\mathrm{g}(l,m) = a \exp \left ( - \frac{1}{2} \left + [ \left (\frac{l''}{\sigma_l} \right)^2 + \left( \frac{m''}{\sigma_m} \right )^2 \right ] \right ) Args: @@ -375,7 +375,7 @@ def sky_gaussian_radians( delta_l : offset [radians] delta_m : offset [radians] sigma_l : width [radians] - sigma_M : width [radians] + sigma_m : width [radians] Omega : position angle of ascending node [degrees] east of north. Returns: @@ -448,11 +448,11 @@ def fourier_gaussian_lambda_radians( Omega: float, ) -> npt.NDArray[np.floating[Any]]: r""" - Calculate the Fourier plane Gaussian :math:`F_\mathrm{g}(u,v)` corresponding to the - Sky plane Gaussian :math:`f_\mathrm{g}(l,m)` in - :func:`~mpol.utils.sky_gaussian_radians`, using analytical relationships. The - Fourier Gaussian is parameterized using the sky plane centroid - (``delta_l, delta_m``), widths (``sigma_l, sigma_m``) and rotation (``Omega``). + Calculate the Fourier plane Gaussian :math:`F_\mathrm{g}(u,v)` corresponding to the + Sky plane Gaussian :math:`f_\mathrm{g}(l,m)` in + :func:`~mpol.utils.sky_gaussian_radians`, using analytical relationships. The + Fourier Gaussian is parameterized using the sky plane centroid + (``delta_l, delta_m``), widths (``sigma_l, sigma_m``) and rotation (``Omega``). Assumes that ``a`` was in units of :math:`\mathrm{Jy}/\mathrm{steradian}`. Args: @@ -468,12 +468,12 @@ def fourier_gaussian_lambda_radians( Returns: 2D Gaussian evaluated at input args - The following is a description of how we derived the analytical relationships. In - what follows, all :math:`l` and :math:`m` coordinates are assumed to be in units - of radians and all :math:`u` and :math:`v` coordinates are assumed to be in units + The following is a description of how we derived the analytical relationships. In + what follows, all :math:`l` and :math:`m` coordinates are assumed to be in units + of radians and all :math:`u` and :math:`v` coordinates are assumed to be in units of :math:`\lambda`. - We start from Fourier dual relationships in Bracewell's `The Fourier Transform and + We start from Fourier dual relationships in Bracewell's `The Fourier Transform and Its Applications `_ .. math:: @@ -494,8 +494,8 @@ def fourier_gaussian_lambda_radians( respectively. The sky-plane Gaussian has a maximum value of :math:`a`. - We will use the similarity, rotation, and shift theorems to turn :math:`f_0` into - a form matching :math:`f_\mathrm{g}`, which simultaneously turns :math:`F_0` into + We will use the similarity, rotation, and shift theorems to turn :math:`f_0` into + a form matching :math:`f_\mathrm{g}`, which simultaneously turns :math:`F_0` into :math:`F_\mathrm{g}(u,v)`. The similarity theorem states that (in 1D) @@ -511,30 +511,30 @@ def fourier_gaussian_lambda_radians( f_1(l, m) = a \exp \left(-\frac{1}{2} \left [\left(\frac{l}{\sigma_l}\right)^2 + \left( \frac{m}{\sigma_m} \right)^2 \right] \right). - i.e., something we might call a normalized Gaussian function. Phrased in terms of + i.e., something we might call a normalized Gaussian function. Phrased in terms of :math:`f_0`, :math:`f_1` is .. math:: - f_1(l, m) = f_0\left ( \frac{l}{\sigma_l \sqrt{2 \pi}},\, + f_1(l, m) = f_0\left ( \frac{l}{\sigma_l \sqrt{2 \pi}},\, \frac{m}{\sigma_m \sqrt{2 \pi}}\right). Therefore, according to the similarity theorem, the equivalent :math:`F_1(u,v)` is .. math:: - F_1(u, v) = \sigma_l \sigma_m 2 \pi F_0 \left( \sigma_l \sqrt{2 \pi} u,\, + F_1(u, v) = \sigma_l \sigma_m 2 \pi F_0 \left( \sigma_l \sqrt{2 \pi} u,\, \sigma_m \sqrt{2 \pi} v \right), or .. math:: - F_1(u, v) = a \sigma_l \sigma_m 2 \pi \exp \left ( -2 \pi^2 [\sigma_l^2 u^2 + + F_1(u, v) = a \sigma_l \sigma_m 2 \pi \exp \left ( -2 \pi^2 [\sigma_l^2 u^2 + \sigma_m^2 v^2] \right). - Next, we rotate the Gaussian to match the sky plane rotation. A rotation - :math:`\Omega` in the sky plane is carried out in the same direction in the + Next, we rotate the Gaussian to match the sky plane rotation. A rotation + :math:`\Omega` in the sky plane is carried out in the same direction in the Fourier plane, .. math:: @@ -549,8 +549,8 @@ def fourier_gaussian_lambda_radians( f_2(l, m) = f_1(l', m') \\ F_2(u, v) = F_1(u', m') - Finally, we translate the sky plane Gaussian by amounts :math:`\delta_l`, - :math:`\delta_m`, which corresponds to a phase shift in the Fourier plane Gaussian. + Finally, we translate the sky plane Gaussian by amounts :math:`\delta_l`, + :math:`\delta_m`, which corresponds to a phase shift in the Fourier plane Gaussian. The image plane translation is .. math:: @@ -563,16 +563,16 @@ def fourier_gaussian_lambda_radians( F_3(u,v) = \exp\left (- 2 i \pi [\delta_l u + \delta_m v] \right) F_2(u,v) - We have arrived at the corresponding Fourier Gaussian, :math:`F_\mathrm{g}(u,v) = + We have arrived at the corresponding Fourier Gaussian, :math:`F_\mathrm{g}(u,v) = F_3(u,v)`. The simplified equation is .. math:: - F_\mathrm{g}(u,v) = a \sigma_l \sigma_m 2 \pi \exp \left ( - -2 \pi^2 \left [\sigma_l^2 u'^2 + \sigma_m^2 v'^2 \right] + F_\mathrm{g}(u,v) = a \sigma_l \sigma_m 2 \pi \exp \left ( + -2 \pi^2 \left [\sigma_l^2 u'^2 + \sigma_m^2 v'^2 \right] - 2 i \pi \left [\delta_l u + \delta_m v \right] \right). - N.B. that we have mixed primed (:math:`u'`) and unprimed (:math:`u`) coordinates in + N.B. that we have mixed primed (:math:`u'`) and unprimed (:math:`u`) coordinates in the same equation for brevity. Finally, the same Fourier dual relationship holds diff --git a/test/conftest.py b/test/conftest.py index ceddf634..fa151838 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,14 +1,13 @@ +from importlib.resources import files + import numpy as np import pytest import torch import visread.process from astropy.utils.data import download_file - from mpol import coordinates, fourier, gridding, images, utils from mpol.__init__ import zenodo_record -from importlib.resources import files - # private variables to this module _npz_path = files("mpol.data").joinpath("mock_data.npz") _nchan = 4 @@ -22,18 +21,24 @@ def img2D_butterfly(): """Return the 2D source image of the butterfly, for use as a test image cube.""" archive = np.load(_npz_path) - img = np.float64(archive["img"]) + img = archive["img"] # assuming we're going to go with _cell_size, set the total flux of this image # total flux should be 0.253 Jy from MPoL-examples. return img +@pytest.fixture(scope="session") +def sky_cube(img2D_butterfly): + """Create a sky tensor image cube from the butterfly.""" + print("npix packed cube", img2D_butterfly.shape) + # tile to some nchan, npix, npix + sky_cube = torch.tile(torch.from_numpy(img2D_butterfly), (_nchan, 1, 1)) + return sky_cube @pytest.fixture(scope="session") def packed_cube(img2D_butterfly): """Create a packed tensor image cube from the butterfly.""" - # now (1, npix, npix) print("npix packed cube", img2D_butterfly.shape) # tile to some nchan, npix, npix sky_cube = torch.tile(torch.from_numpy(img2D_butterfly), (_nchan, 1, 1)) @@ -45,7 +50,7 @@ def packed_cube(img2D_butterfly): def baselines_m(): "Return the mock baselines (in meters) produced from the IM Lup DSHARP dataset." archive = np.load(_npz_path) - return np.float64(archive["uu"]), np.float64(archive["vv"]) + return archive["uu"], archive["vv"] @pytest.fixture(scope="session") @@ -199,7 +204,6 @@ def mock_1d_vis_model(mock_1d_archive): geom = m["geometry"] geom = geom[()] - Vtrue = m["vis"] Vtrue_dep = m["vis_dep"] q_dep = m["baselines_dep"] diff --git a/test/coordinates_test.py b/test/coordinates_test.py index 9b027461..1b22b351 100644 --- a/test/coordinates_test.py +++ b/test/coordinates_test.py @@ -1,9 +1,8 @@ import matplotlib.pyplot as plt -import torch +import numpy as np import pytest - -from mpol import coordinates -from mpol.constants import * +import torch +from mpol import constants, coordinates from mpol.exceptions import CellSizeError @@ -47,7 +46,7 @@ def test_grid_coords_plot_2D_uvq_sky(tmp_path): im = ax[2].imshow(coords.ground_q_centers_2D, **ikw) plt.colorbar(im, ax=ax[2]) - for a, t in zip(ax, ["u", "v", "q"]): + for a, t in zip(ax, ["u", "v", "q"], strict=False): a.set_title(t) fig.savefig(tmp_path / "sky_uvq.png", dpi=300) @@ -68,7 +67,7 @@ def test_grid_coords_plot_2D_uvq_packed(tmp_path): im = ax[2].imshow(coords.packed_q_centers_2D, **ikw) plt.colorbar(im, ax=ax[2]) - for a, t in zip(ax, ["u", "v", "q"]): + for a, t in zip(ax, ["u", "v", "q"], strict=False): a.set_title(t) fig.savefig(tmp_path / "packed_uvq.png", dpi=300) @@ -117,7 +116,7 @@ def test_tile_vs_meshgrid_implementation(): coords = coordinates.GridCoords(cell_size=0.05, npix=800) x_centers_2d, y_centers_2d = np.meshgrid( - coords.l_centers / arcsec, coords.m_centers / arcsec, indexing="xy" + coords.l_centers / constants.arcsec, coords.m_centers / constants.arcsec, indexing="xy" ) ground_u_centers_2D, ground_v_centers_2D = np.meshgrid( diff --git a/test/crossval_test.py b/test/crossval_test.py index db4d2566..4f67faae 100644 --- a/test/crossval_test.py +++ b/test/crossval_test.py @@ -2,13 +2,10 @@ import matplotlib.pyplot as plt import numpy as np -import torch # from mpol.crossval import CrossValidate, DartboardSplitGridded, RandomCellSplitGridded from mpol.crossval import DartboardSplitGridded, RandomCellSplitGridded from mpol.datasets import Dartboard -from mpol.plot import split_diagnostics_fig - # def test_crossvalclass_split_dartboard(coords, imager, dataset, generic_parameters): # # using the CrossValidate class, split a dataset into train/test subsets @@ -147,4 +144,4 @@ def test_dartboardsplit_iterate_masks(coords, dataset, tmp_path): ax[0, 0].set_title("train") ax[0, 1].set_title("test") - fig.savefig(tmp_path / "masks", dpi=300) \ No newline at end of file + fig.savefig(tmp_path / "masks", dpi=300) diff --git a/test/datasets_test.py b/test/datasets_test.py index 49245148..34923ed5 100644 --- a/test/datasets_test.py +++ b/test/datasets_test.py @@ -4,7 +4,6 @@ import matplotlib.pyplot as plt import numpy as np import torch - from mpol import datasets, fourier, images @@ -16,11 +15,9 @@ def test_index(coords, dataset): # create a mock cube that includes negative values nchan = dataset.nchan mean = torch.full( - (nchan, coords.npix, coords.npix), fill_value=-0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=-0.5) std = torch.full( - (nchan, coords.npix, coords.npix), fill_value=0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=0.5) # tensor base_cube = torch.normal(mean=mean, std=std) diff --git a/test/fftshift_test.py b/test/fftshift_test.py index 909e11bf..52ccca6b 100644 --- a/test/fftshift_test.py +++ b/test/fftshift_test.py @@ -2,8 +2,6 @@ import numpy as np import torch -import mpol.utils - def test_mpol_fftshift(tmp_path): diff --git a/test/fourier_test.py b/test/fourier_test.py index 0d9b31d1..e3ef6932 100644 --- a/test/fourier_test.py +++ b/test/fourier_test.py @@ -1,10 +1,9 @@ import matplotlib.pyplot as plt import numpy as np import torch +from mpol import fourier, utils from pytest import approx -from mpol import fourier, images, utils - def test_fourier_cube(coords, tmp_path): # test image packing @@ -142,17 +141,14 @@ def test_predict_vis_nufft(coords, baselines_1D): nchan = 10 - # instantiate an BaseCube layer filled with zeros - basecube = images.BaseCube(coords=coords, nchan=nchan, pixel_mapping=lambda x: x) - imagecube = images.ImageCube(coords=coords, nchan=nchan) - # we have a multi-channel cube, but only sent single-channel uu and vv # coordinates. The expectation is that TorchKbNufft will parallelize these layer = fourier.NuFFT(coords=coords, nchan=nchan) # predict the values of the cube at the u,v locations - output = layer(imagecube(basecube()), uu, vv) + blank_packed_img = torch.zeros((nchan, coords.npix, coords.npix)) + output = layer(blank_packed_img, uu, vv) # make sure we got back the number of visibilities we expected assert output.shape == (nchan, len(uu)) @@ -172,62 +168,55 @@ def test_predict_vis_nufft_cached(coords, baselines_1D): nchan = 10 - # instantiate an ImageCube layer filled with zeros - # instantiate an BaseCube layer filled with zeros - basecube = images.BaseCube(coords=coords, nchan=nchan, pixel_mapping=lambda x: x) - imagecube = images.ImageCube(coords=coords, nchan=nchan) - # we have a multi-channel cube, but sent only single-channel uu and vv # coordinates. The expectation is that TorchKbNufft will parallelize these - layer = fourier.NuFFTCached(coords=coords, nchan=nchan, uu=uu, vv=vv) # predict the values of the cube at the u,v locations - output = layer(imagecube(basecube())) + blank_packed_img = torch.zeros((nchan, coords.npix, coords.npix)) + output = layer(blank_packed_img) # make sure we got back the number of visibilities we expected assert output.shape == (nchan, len(uu)) # if the image cube was filled with zeros, then we should make sure this is true assert output.detach().numpy() == approx( - np.zeros((nchan, len(uu)), dtype=np.complex128) + np.zeros((nchan, len(uu))) ) def test_nufft_cached_predict_GPU(coords, baselines_1D): - if not torch.cuda.is_available(): - pass + if torch.cuda.is_available(): + device = torch.device("cuda") else: - device = torch.device("cuda:0") - - # just see that we can load the layer and get something through without error - # for a very simple blank function + return - # load some data - uu, vv = baselines_1D + # just see that we can load the layer and get something through without error + # for a very simple blank function - nchan = 10 + # load some data + uu, vv = baselines_1D - # instantiate an ImageCube layer filled with zeros and send to GPU - imagecube = images.ImageCube(coords=coords, nchan=nchan).to(device=device) + nchan = 10 - # we have a multi-channel cube, but only sent single-channel uu and vv - # coordinates. The expectation is that TorchKbNufft will parallelize these + # we have a multi-channel cube, but only sent single-channel uu and vv + # coordinates. The expectation is that TorchKbNufft will parallelize these - layer = fourier.NuFFTCached(coords=coords, nchan=nchan, uu=uu, vv=vv).to( - device=device - ) + layer = fourier.NuFFTCached(coords=coords, nchan=nchan, uu=uu, vv=vv).to( + device=device + ) - # predict the values of the cube at the u,v locations - output = layer(imagecube()) + # predict the values of the cube at the u,v locations + blank_packed_img = torch.zeros((nchan, coords.npix, coords.npix)).to(device=device) + output = layer(blank_packed_img) - # make sure we got back the number of visibilities we expected - assert output.shape == (nchan, len(uu)) + # make sure we got back the number of visibilities we expected + assert output.shape == (nchan, len(uu)) - # if the image cube was filled with zeros, then we should make sure this is true - assert output.cpu().detach().numpy() == approx( - np.zeros((nchan, len(uu)), dtype=np.complex128) - ) + # if the image cube was filled with zeros, then we should make sure this is true + assert output.cpu().detach().numpy() == approx( + np.zeros((nchan, len(uu)), dtype=np.complex128) + ) def test_nufft_accuracy_single_chan(coords, baselines_1D, tmp_path): @@ -325,12 +314,10 @@ def test_nufft_cached_accuracy_single_chan(coords, baselines_1D, tmp_path): img_packed = utils.sky_gaussian_arcsec( coords.packed_x_centers_2D, coords.packed_y_centers_2D, **kw ) - img_packed_tensor = torch.tensor(img_packed[np.newaxis, :, :], requires_grad=True) + img_packed_tensor = torch.tensor(img_packed[np.newaxis, :, :], requires_grad=True, dtype=torch.float32) # use the NuFFT to predict the values of the cube at the u,v locations - num_output = ( - layer(img_packed_tensor)[0] - ) # take the channel dim out + num_output = layer(img_packed_tensor)[0] # take the channel dim out # calculate the values analytically an_output = utils.fourier_gaussian_lambda_arcsec(uu, vv, **kw) @@ -341,7 +328,6 @@ def test_nufft_cached_accuracy_single_chan(coords, baselines_1D, tmp_path): max = torch.max(torch.abs(num_output)) print(max_diff, max) - # collapse the function into 1D by doing q qq = utils.torch2npy(torch.hypot(uu, vv)) @@ -404,7 +390,7 @@ def test_nufft_cached_accuracy_coil_broadcast(coords, baselines_1D): # broadcast to 5 channels -- the image will be the same for each img_packed_tensor = torch.tensor( img_packed[np.newaxis, :, :] * np.ones((nchan, coords.npix, coords.npix)), - requires_grad=True, + requires_grad=True, dtype=torch.float32 ) # use the NuFFT to predict the values of the cube at the u,v locations diff --git a/test/geometry_test.py b/test/geometry_test.py index b19b0c2a..84ca6d42 100644 --- a/test/geometry_test.py +++ b/test/geometry_test.py @@ -1,13 +1,13 @@ +import numpy as np import torch +from mpol import geometry from pytest import approx -import numpy as np -from mpol import geometry def test_rotate_points(): - ''' + """ Test rotation from flat 2D frame to observer frame and back - ''' + """ xs = torch.tensor([0.0, 1.0, 2.0]) ys = torch.tensor([1.0, -1.0, 2.0]) @@ -29,14 +29,13 @@ def test_rotate_points(): def test_rotate_coords(coords): - + omega = 35. * np.pi/180 incl = 30. * np.pi/180 Omega = 210. * np.pi/180 x, y = geometry.observer_to_flat(coords.sky_x_centers_2D, coords.sky_y_centers_2D, omega=omega, incl=incl, Omega=Omega) - + print(x, y) - - \ No newline at end of file + diff --git a/test/gridder_dataset_export_test.py b/test/gridder_dataset_export_test.py index 813172d5..9938571b 100644 --- a/test/gridder_dataset_export_test.py +++ b/test/gridder_dataset_export_test.py @@ -1,8 +1,6 @@ import numpy as np import pytest - from mpol import coordinates, gridding -from mpol.constants import * def test_cell_variance_error_pytorch(mock_dataset_np): diff --git a/test/gridder_gridding_test.py b/test/gridder_gridding_test.py index fe8c14f9..8a9fa079 100644 --- a/test/gridder_gridding_test.py +++ b/test/gridder_gridding_test.py @@ -1,9 +1,7 @@ import matplotlib.pyplot as plt import numpy as np import pytest - from mpol import coordinates, gridding -from mpol.constants import * def test_average_cont(coords, mock_dataset_np): diff --git a/test/gridder_imager_test.py b/test/gridder_imager_test.py index e36f98c9..1590028a 100644 --- a/test/gridder_imager_test.py +++ b/test/gridder_imager_test.py @@ -4,9 +4,7 @@ import matplotlib.pyplot as plt import numpy as np import pytest - from mpol import coordinates, gridding -from mpol.constants import * # cache an instantiated imager for future imaging ops @@ -141,7 +139,7 @@ def test_grid_uniform(imager, tmp_path): ax[1, 0].imshow(img_uniform[chan], **kw) ax[0, 1].imshow(beam_robust[chan], **kw) - ax[0, 1].set_title("robust={:}".format(r)) + ax[0, 1].set_title(f"robust={r}") ax[1, 1].imshow(img_robust[chan], **kw) # the differences @@ -183,7 +181,7 @@ def test_grid_uniform_arcsec2(imager, tmp_path): plt.colorbar(im, ax=ax[1, 0]) ax[0, 1].imshow(beam_robust[chan], **kw) - ax[0, 1].set_title("robust={:}".format(r)) + ax[0, 1].set_title(f"robust={r}") im = ax[1, 1].imshow(img_robust[chan], **kw) plt.colorbar(im, ax=ax[1, 1]) @@ -225,7 +223,7 @@ def test_grid_natural(imager, tmp_path): ax[1, 0].imshow(img_natural[chan], **kw) ax[0, 1].imshow(beam_robust[chan], **kw) - ax[0, 1].set_title("robust={:}".format(r)) + ax[0, 1].set_title(f"robust={r}") ax[1, 1].imshow(img_robust[chan], **kw) # the differences @@ -267,7 +265,7 @@ def test_grid_natural_arcsec2(imager, tmp_path): plt.colorbar(im, ax=ax[1, 0]) ax[0, 1].imshow(beam_robust[chan], **kw) - ax[0, 1].set_title("robust={:}".format(r)) + ax[0, 1].set_title(f"robust={r}") im = ax[1, 1].imshow(img_robust[chan], **kw) plt.colorbar(im, ax=ax[1, 1]) diff --git a/test/gridder_init_test.py b/test/gridder_init_test.py index 330f6a99..d6639c8f 100644 --- a/test/gridder_init_test.py +++ b/test/gridder_init_test.py @@ -1,7 +1,6 @@ +import numpy as np import pytest - from mpol import coordinates, gridding -from mpol.constants import * from mpol.exceptions import CellSizeError, DataError diff --git a/test/images_test.py b/test/images_test.py index b365c763..b60b14ec 100644 --- a/test/images_test.py +++ b/test/images_test.py @@ -1,10 +1,10 @@ import matplotlib.pyplot as plt +import numpy as np import pytest import torch from astropy.io import fits +from mpol import coordinates, images, plot, utils -from mpol import coordinates, images, utils -from mpol.constants import * def test_single_chan(): coords = coordinates.GridCoords(cell_size=0.015, npix=800) @@ -56,11 +56,9 @@ def test_basecube_imagecube(coords, tmp_path): # create a mock cube that includes negative values nchan = 1 mean = torch.full( - (nchan, coords.npix, coords.npix), fill_value=-0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=-0.5) std = torch.full( - (nchan, coords.npix, coords.npix), fill_value=0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=0.5) # tensor base_cube = torch.normal(mean=mean, std=std) @@ -111,11 +109,9 @@ def test_base_cube_conv_cube(coords, tmp_path): # create a mock cube that includes negative values nchan = 1 mean = torch.full( - (nchan, coords.npix, coords.npix), fill_value=-0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=-0.5) std = torch.full( - (nchan, coords.npix, coords.npix), fill_value=0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=0.5) # The HannConvCube expects to function on a pre-packed ImageCube, # so in order to get the plots looking correct on this test image, @@ -156,11 +152,9 @@ def test_multi_chan_conv(coords, tmp_path): nchan = 10 mean = torch.full( - (nchan, coords.npix, coords.npix), fill_value=-0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=-0.5) std = torch.full( - (nchan, coords.npix, coords.npix), fill_value=0.5, dtype=torch.double - ) + (nchan, coords.npix, coords.npix), fill_value=0.5) # tensor test_cube = torch.normal(mean=mean, std=std) @@ -177,3 +171,236 @@ def test_image_flux(coords): im = images.ImageCube(coords=coords, nchan=nchan) im(bcube()) assert im.flux.size()[0] == nchan + + +def test_plot_test_img(packed_cube, coords, tmp_path): + # show only the first channel + chan = 0 + fig, ax = plt.subplots(nrows=1) + + # put back to sky + sky_cube = utils.packed_cube_to_sky_cube(packed_cube) + im = ax.imshow( + sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno" + ) + plt.colorbar(im) + fig.savefig(tmp_path / "sky_cube.png", dpi=300) + + plt.close("all") + +def test_taper(coords, tmp_path): + for r in np.arange(0.0, 0.2, step=0.02): + fig, ax = plt.subplots(ncols=1) + + taper_2D = images.uv_gaussian_taper(coords, r, r, 0.0) + print(type(taper_2D)) + + norm = plot.get_image_cmap_norm(taper_2D, symmetric=True) + im = ax.imshow( + taper_2D, + extent=coords.vis_ext_Mlam, + origin="lower", + cmap="bwr_r", + norm=norm, + ) + plt.colorbar(im, ax=ax) + + fig.savefig(tmp_path / f"taper{r:.2f}.png", dpi=300) + + plt.close("all") + +def test_gaussian_kernel(coords, tmp_path): + rs = np.array([0.02, 0.06, 0.10]) + nchan = 3 + fig, ax = plt.subplots(nrows=len(rs), ncols=nchan, figsize=(10,10)) + for i,r in enumerate(rs): + layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=0.5 * r) + weight = layer.m.weight.detach().numpy() + for j in range(nchan): + im = ax[i,j].imshow(weight[j,0], interpolation="none", origin="lower") + plt.colorbar(im, ax=ax[i,j]) + + fig.savefig(tmp_path / "filter.png", dpi=300) + plt.close("all") + +def test_gaussian_kernel_rotate(coords, tmp_path): + r = 0.04 + Omegas = [0, 20, 40] # degrees + nchan = 3 + fig, ax = plt.subplots(nrows=len(Omegas), ncols=nchan, figsize=(10, 10)) + for i, Omega in enumerate(Omegas): + layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=0.5 * r, Omega=Omega) + weight = layer.m.weight.detach().numpy() + for j in range(nchan): + im = ax[i, j].imshow(weight[j, 0], interpolation="none",origin="lower") + plt.colorbar(im, ax=ax[i, j]) + + fig.savefig(tmp_path / "filter.png", dpi=300) + plt.close("all") + + +def test_GaussConvImage(sky_cube, coords, tmp_path): + # show only the first channel + chan = 0 + nchan = sky_cube.size()[0] + + for r in np.arange(0.02, 0.11, step=0.04): + + layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=r) + + print("Kernel size", layer.m.weight.size()) + + fig, ax = plt.subplots(ncols=2) + + im = ax[0].imshow( + sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno" + ) + flux = coords.cell_size**2 * torch.sum(sky_cube[chan]) + ax[0].set_title(f"tot flux: {flux:.3f} Jy") + plt.colorbar(im, ax=ax[0]) + + c_sky = layer(sky_cube) + im = ax[1].imshow( + c_sky[chan], extent=coords.img_ext, origin="lower", cmap="inferno" + ) + flux = coords.cell_size**2 * torch.sum(c_sky[chan]) + ax[1].set_title(f"tot flux: {flux:.3f} Jy") + + plt.colorbar(im, ax=ax[1]) + fig.savefig(tmp_path / f"convolved_{r:.2f}.png", dpi=300) + + plt.close("all") + +def test_GaussConvImage_rotate(sky_cube, coords, tmp_path): + # show only the first channel + chan = 0 + nchan = sky_cube.size()[0] + + for Omega in [0, 20, 40]: + layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=0.16, FWHM_min=0.06, Omega=Omega) + + fig, ax = plt.subplots(ncols=2) + + im = ax[0].imshow( + sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno" + ) + flux = coords.cell_size**2 * torch.sum(sky_cube[chan]) + ax[0].set_title(f"tot flux: {flux:.3f} Jy") + plt.colorbar(im, ax=ax[0]) + + c_sky = layer(sky_cube) + im = ax[1].imshow( + c_sky[chan], extent=coords.img_ext, origin="lower", cmap="inferno" + ) + flux = coords.cell_size**2 * torch.sum(c_sky[chan]) + ax[1].set_title(f"tot flux: {flux:.3f} Jy") + + plt.colorbar(im, ax=ax[1]) + fig.savefig(tmp_path / f"convolved_{Omega:.2f}.png", dpi=300) + + plt.close("all") + +def test_GaussFourier(packed_cube, coords, tmp_path): + chan = 0 + + for FWHM in np.linspace(0.02, 0.5, num=10): + fig, ax = plt.subplots(ncols=2) + # put back to sky + sky_cube = utils.packed_cube_to_sky_cube(packed_cube) + im = ax[0].imshow( + sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno" + ) + flux = coords.cell_size**2 * torch.sum(sky_cube[chan]) + ax[0].set_title(f"tot flux: {flux:.3f} Jy") + plt.colorbar(im, ax=ax[0]) + + # set base resolution + layer = images.GaussConvFourier(coords, FWHM, FWHM) + c = layer(packed_cube) + # put back to sky + c_sky = utils.packed_cube_to_sky_cube(c) + flux = coords.cell_size**2 * torch.sum(c_sky[chan]) + im = ax[1].imshow( + c_sky[chan].detach().numpy(), + extent=coords.img_ext, + origin="lower", + cmap="inferno", + ) + ax[1].set_title(f"tot flux: {flux:.3f} Jy") + + plt.colorbar(im, ax=ax[1]) + fig.savefig(tmp_path / "convolved_FWHM_{:.2f}.png".format(FWHM), dpi=300) + + plt.close("all") + +def test_GaussFourier_rotate(packed_cube, coords, tmp_path): + chan = 0 + + sky_cube = utils.packed_cube_to_sky_cube(packed_cube) + + for Omega in [0, 20, 40]: + layer = images.GaussConvFourier( + coords, FWHM_maj=0.16, FWHM_min=0.06, Omega=Omega + ) + + fig, ax = plt.subplots(ncols=2) + + im = ax[0].imshow( + sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno" + ) + flux = coords.cell_size**2 * torch.sum(sky_cube[chan]) + ax[0].set_title(f"tot flux: {flux:.3f} Jy") + plt.colorbar(im, ax=ax[0]) + + c_sky = layer(sky_cube) + im = ax[1].imshow( + c_sky[chan], extent=coords.img_ext, origin="lower", cmap="inferno" + ) + flux = coords.cell_size**2 * torch.sum(c_sky[chan]) + ax[1].set_title(f"tot flux: {flux:.3f} Jy") + + plt.colorbar(im, ax=ax[1]) + fig.savefig(tmp_path / f"convolved_{Omega:.2f}.png", dpi=300) + + plt.close("all") + + +def test_GaussFourier_point(coords, tmp_path): + FWHM = 0.5 + + # create an image with a point source in the center + sky_cube = torch.zeros((1, coords.npix, coords.npix)) + cpix = coords.npix//2 + sky_cube[0,cpix,cpix] = 1.0 + + fig, ax = plt.subplots(ncols=2, sharex=True, sharey=True) + # put back to sky + im = ax[0].imshow( + sky_cube[0], extent=coords.img_ext, origin="lower", cmap="inferno" + ) + flux = coords.cell_size**2 * torch.sum(sky_cube[0]) + ax[0].set_title(f"tot flux: {flux:.3f} Jy") + plt.colorbar(im, ax=ax[0]) + + # set base resolution + layer = images.GaussConvFourier(coords, FWHM, FWHM) + packed_cube = utils.sky_cube_to_packed_cube(sky_cube) + c = layer(packed_cube) + # put back to sky + c_sky = utils.packed_cube_to_sky_cube(c) + flux = coords.cell_size**2 * torch.sum(c_sky[0]) + im = ax[1].imshow( + c_sky[0].detach().numpy(), + extent=coords.img_ext, + origin="lower", + cmap="inferno", + ) + ax[1].set_title(f"tot flux: {flux:.3f} Jy") + r = 0.7 + ax[1].set_xlim(r, -r) + ax[1].set_ylim(-r, r) + + plt.colorbar(im, ax=ax[1]) + fig.savefig(tmp_path / "point_source_FWHM_{:.2f}.png".format(FWHM), dpi=300) + + plt.close("all") diff --git a/test/input_output_test.py b/test/input_output_test.py index 39d1a4a6..31821534 100644 --- a/test/input_output_test.py +++ b/test/input_output_test.py @@ -1,9 +1,8 @@ -import numpy as np from astropy.utils.data import download_file - from mpol.input_output import ProcessFitsImage + def test_ProcessFitsImage(): # get a .fits file produced with casa fname = download_file( @@ -12,6 +11,6 @@ def test_ProcessFitsImage(): show_progress=True, pkgname="mpol", ) - + fits_image = ProcessFitsImage(fname) - clean_im, clean_im_ext, clean_beam = fits_image.get_image(beam=True) \ No newline at end of file + clean_im, clean_im_ext, clean_beam = fits_image.get_image(beam=True) diff --git a/test/losses_test.py b/test/losses_test.py index 8072d3e4..3ffcbc1b 100644 --- a/test/losses_test.py +++ b/test/losses_test.py @@ -1,7 +1,6 @@ import numpy as np import pytest import torch - from mpol import coordinates, fourier, losses diff --git a/test/onedim_test.py b/test/onedim_test.py index 04404e46..da395c66 100644 --- a/test/onedim_test.py +++ b/test/onedim_test.py @@ -1,12 +1,12 @@ import matplotlib.pyplot as plt import numpy as np - from mpol.onedim import radialI, radialV from mpol.plot import plot_image from mpol.utils import torch2npy + def test_radialI(mock_1d_image_model, tmp_path): - # obtain a 1d radial brightness profile I(r) from an image + # obtain a 1d radial brightness profile I(r) from an image rtrue, itrue, icube, _, _, geom = mock_1d_image_model @@ -16,45 +16,45 @@ def test_radialI(mock_1d_image_model, tmp_path): fig, ax = plt.subplots(ncols=2, figsize=(10,5)) - plot_image(np.squeeze(torch2npy(icube.sky_cube)), extent=icube.coords.img_ext, - ax=ax[0], clab='Jy / sr') + plot_image(np.squeeze(torch2npy(icube.sky_cube)), extent=icube.coords.img_ext, + ax=ax[0], clab="Jy / sr") + + ax[1].plot(rtrue, itrue, "k", label="truth") + ax[1].plot(rtest, itest, "r.-", label="recovery") - ax[1].plot(rtrue, itrue, 'k', label='truth') - ax[1].plot(rtest, itest, 'r.-', label='recovery') - ax[0].set_title(f"Geometry:\n{geom}", fontsize=7) - - ax[1].set_xlabel('r [arcsec]') - ax[1].set_ylabel('I [Jy / sr]') + + ax[1].set_xlabel("r [arcsec]") + ax[1].set_ylabel("I [Jy / sr]") ax[1].legend() fig.savefig(tmp_path / "test_radialI.png", dpi=300) plt.close("all") expected = [ - 6.40747314e+10, 4.01920507e+10, 1.44803534e+10, 2.94238627e+09, - 1.28782935e+10, 2.68613199e+10, 2.26564596e+10, 1.81151845e+10, + 6.40747314e+10, 4.01920507e+10, 1.44803534e+10, 2.94238627e+09, + 1.28782935e+10, 2.68613199e+10, 2.26564596e+10, 1.81151845e+10, 1.52128965e+10, 1.05640352e+10, 1.33411204e+10, 1.61124502e+10, - 1.41500539e+10, 1.20121195e+10, 1.11770326e+10, 1.19676913e+10, - 1.20941686e+10, 1.09498286e+10, 9.74236410e+09, 7.99589196e+09, + 1.41500539e+10, 1.20121195e+10, 1.11770326e+10, 1.19676913e+10, + 1.20941686e+10, 1.09498286e+10, 9.74236410e+09, 7.99589196e+09, 5.94787809e+09, 3.82074946e+09, 1.80823933e+09, 4.48414819e+08, - 3.17808840e+08, 5.77317876e+08, 3.98851281e+08, 8.06459834e+08, - 2.88706161e+09, 6.09577814e+09, 6.98556762e+09, 4.47436415e+09, + 3.17808840e+08, 5.77317876e+08, 3.98851281e+08, 8.06459834e+08, + 2.88706161e+09, 6.09577814e+09, 6.98556762e+09, 4.47436415e+09, 1.89511273e+09, 5.96604356e+08, 3.44571640e+08, 5.65906765e+08, - 2.85854589e+08, 2.67589013e+08, 3.98357054e+08, 2.97052261e+08, - 3.82744591e+08, 3.52239791e+08, 2.74336969e+08, 2.28425747e+08, + 2.85854589e+08, 2.67589013e+08, 3.98357054e+08, 2.97052261e+08, + 3.82744591e+08, 3.52239791e+08, 2.74336969e+08, 2.28425747e+08, 1.82290043e+08, 3.16077299e+08, 1.18465538e+09, 3.32239287e+09, - 5.26718846e+09, 5.16458748e+09, 3.58114198e+09, 2.13431954e+09, - 1.40936556e+09, 1.04032244e+09, 9.24050422e+08, 8.46829316e+08, + 5.26718846e+09, 5.16458748e+09, 3.58114198e+09, 2.13431954e+09, + 1.40936556e+09, 1.04032244e+09, 9.24050422e+08, 8.46829316e+08, 6.80909295e+08, 6.83812465e+08, 6.91856237e+08, 5.29227136e+08, - 3.97557293e+08, 3.54893419e+08, 2.60997039e+08, 2.09306498e+08, - 1.93930693e+08, 6.97032407e+07, 6.66090083e+07, 1.40079594e+08, + 3.97557293e+08, 3.54893419e+08, 2.60997039e+08, 2.09306498e+08, + 1.93930693e+08, 6.97032407e+07, 6.66090083e+07, 1.40079594e+08, 7.21775931e+07, 3.23902663e+07, 3.35932300e+07, 7.63318789e+06, - 1.29740981e+07, 1.44300351e+07, 8.06249624e+06, 5.85567843e+06, - 1.42637174e+06, 3.21445075e+06, 1.83763663e+06, 1.16926652e+07, + 1.29740981e+07, 1.44300351e+07, 8.06249624e+06, 5.85567843e+06, + 1.42637174e+06, 3.21445075e+06, 1.83763663e+06, 1.16926652e+07, 2.46918188e+07, 1.60206523e+07, 3.26596592e+06, 1.27837319e+05, - 2.27104612e+04, 4.77267063e+03, 2.90467640e+03, 2.88482230e+03, - 1.43402521e+03, 1.54791996e+03, 7.23397046e+02, 1.02561351e+03, + 2.27104612e+04, 4.77267063e+03, 2.90467640e+03, 2.88482230e+03, + 1.43402521e+03, 1.54791996e+03, 7.23397046e+02, 1.02561351e+03, 5.24845888e+02, 1.47320552e+03, 7.40419174e+02, 4.59029378e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00 ] @@ -74,17 +74,17 @@ def test_radialV(mock_1d_vis_model, tmp_path): fig, ax = plt.subplots(ncols=1, nrows=2, figsize=(10,10)) - ax[0].plot(q_dep / 1e6, Vtrue_dep.real, 'k.', label='truth deprojected') - ax[0].plot(qtest / 1e3, Vtest.real, 'r.-', label='recovery') + ax[0].plot(q_dep / 1e6, Vtrue_dep.real, "k.", label="truth deprojected") + ax[0].plot(qtest / 1e3, Vtest.real, "r.-", label="recovery") - ax[1].plot(q_dep / 1e6, Vtrue_dep.imag, 'k.') - ax[1].plot(qtest / 1e3, Vtest.imag, 'r.') + ax[1].plot(q_dep / 1e6, Vtrue_dep.imag, "k.") + ax[1].plot(qtest / 1e3, Vtest.imag, "r.") ax[0].set_xlim(-0.5, 6) ax[1].set_xlim(-0.5, 6) - ax[1].set_xlabel(r'Baseline [M$\lambda$]') - ax[0].set_ylabel('Re(V) [Jy]') - ax[1].set_ylabel('Im(V) [Jy]') + ax[1].set_xlabel(r"Baseline [M$\lambda$]") + ax[0].set_ylabel("Re(V) [Jy]") + ax[1].set_ylabel("Im(V) [Jy]") ax[0].set_title(f"Geometry {geom}", fontsize=10) ax[0].legend() @@ -120,4 +120,4 @@ def test_radialV(mock_1d_vis_model, tmp_path): ] np.testing.assert_allclose(Vtest.real, expected, rtol=1e-6, - err_msg="test_radialV") \ No newline at end of file + err_msg="test_radialV") diff --git a/test/plot_test.py b/test/plot_test.py index 4b01ed1b..736df7ed 100644 --- a/test/plot_test.py +++ b/test/plot_test.py @@ -1,8 +1,5 @@ -import numpy as np -from astropy.utils.data import download_file -from mpol import precomposed # from mpol.plot import image_comparison_fig # def test_image_comparison_fig(coords, tmp_path): @@ -23,10 +20,10 @@ # pkgname="mpol", # ) -# image_comparison_fig(model, u, v, V, weights, robust=0.5, +# image_comparison_fig(model, u, v, V, weights, robust=0.5, # clean_fits=fname, -# share_cscale=False, +# share_cscale=False, # xzoom=[-2, 2], yzoom=[-2, 2], # title="test", -# save_prefix=None, +# save_prefix=None, # ) diff --git a/test/train_test_test.py b/test/train_test_test.py index dd519d94..26a0fd9e 100644 --- a/test/train_test_test.py +++ b/test/train_test_test.py @@ -2,13 +2,12 @@ import numpy as np import torch import torch.optim -from torch.utils.tensorboard import SummaryWriter - from mpol import losses, precomposed + # from mpol.plot import train_diagnostics_fig # from mpol.training import TrainTest, train_to_dirty_image from mpol.utils import torch2npy - +from torch.utils.tensorboard import SummaryWriter # def test_traintestclass_training(coords, imager, dataset, generic_parameters): # # using the TrainTest class, run a training loop without regularizers @@ -16,7 +15,7 @@ # model = precomposed.GriddedNet(coords=coords, nchan=nchan) # train_pars = generic_parameters["train_pars"] - + # # no regularizers # train_pars["regularizers"] = {} @@ -29,7 +28,7 @@ # def test_traintestclass_training_scheduler(coords, imager, dataset, generic_parameters): -# # using the TrainTest class, run a training loop with regularizers, +# # using the TrainTest class, run a training loop with regularizers, # # using the learning rate scheduler # nchan = dataset.nchan # model = precomposed.GriddedNet(coords=coords, nchan=nchan) @@ -54,11 +53,11 @@ # nchan = dataset.nchan # model = precomposed.GriddedNet(coords=coords, nchan=nchan) -# train_pars = generic_parameters["train_pars"] +# train_pars = generic_parameters["train_pars"] # learn_rate = generic_parameters["crossval_pars"]["learn_rate"] -# train_pars['regularizers']['entropy']['guess'] = True +# train_pars['regularizers']['entropy']['guess'] = True # optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate) @@ -67,8 +66,8 @@ # def test_traintestclass_train_diagnostics_fig(coords, imager, dataset, generic_parameters, tmp_path): -# # using the TrainTest class, run a training loop, -# # and generate the train diagnostics figure +# # using the TrainTest class, run a training loop, +# # and generate the train diagnostics figure # nchan = dataset.nchan # model = precomposed.GriddedNet(coords=coords, nchan=nchan) @@ -87,8 +86,8 @@ # old_mod_im = torch2npy(model.icube.sky_cube[0]) -# train_fig, train_axes = train_diagnostics_fig(model, -# losses=loss_history, +# train_fig, train_axes = train_diagnostics_fig(model, +# losses=loss_history, # learn_rates=learn_rates, # fluxes=np.zeros(len(loss_history)), # old_model_image=old_mod_im @@ -111,7 +110,7 @@ def test_standalone_init_train(coords, dataset): - # not using TrainTest class, + # not using TrainTest class, # configure a class to train with and test that it initializes nchan = dataset.nchan @@ -131,7 +130,7 @@ def test_standalone_init_train(coords, dataset): def test_standalone_train_loop(coords, dataset_cont, tmp_path): - # not using TrainTest class, + # not using TrainTest class, # set everything up to run on a single channel # and run a few iterations @@ -176,7 +175,7 @@ def test_standalone_train_loop(coords, dataset_cont, tmp_path): def test_tensorboard(coords, dataset_cont): - # not using TrainTest class, + # not using TrainTest class, # set everything up to run on a single channel and then # test the writer function diff --git a/test/utils_test.py b/test/utils_test.py index d4f17b84..5fe41ad1 100644 --- a/test/utils_test.py +++ b/test/utils_test.py @@ -1,9 +1,7 @@ import matplotlib.pyplot as plt import numpy as np import pytest - from mpol import coordinates, utils -from mpol.constants import * @pytest.fixture @@ -140,4 +138,4 @@ def test_get_optimal_image_properties(baselines_1D): max_data_freq = max(abs(u).max(), abs(v).max()) - assert(utils.get_max_spatial_freq(cell_size, npix) >= max_data_freq) \ No newline at end of file + assert(utils.get_max_spatial_freq(cell_size, npix) >= max_data_freq)