Skip to content

Commit

Permalink
tests for PSF fitting methods
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 committed Jul 26, 2023
1 parent 4f739b6 commit f8baf48
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 22 deletions.
64 changes: 42 additions & 22 deletions romancal/lib/psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
PSFPhotometry,
SourceGrouper,
)
from webbpsf import gridded_library
from webbpsf import setup_logging as webbpsf_logging
from roman_datamodels.datamodels import ImageModel
from webbpsf import conf, gridded_library, restart_logging

from romancal.lib.dqflags import pixel as dq_flag_map
from romancal.lib.dqflags import pixel as roman_dq_flag_map

# set loggers to debug level by default:
log = logging.getLogger(__name__)
Expand All @@ -43,26 +43,29 @@


def create_gridded_psf_model(
path_prefix,
filt,
path,
detector,
oversample=12,
fov_pixels=12,
sqrt_n_psfs=4,
overwrite=False,
buffer_pixels=100,
instrument_options=None,
logging_level=None,
):
"""
Compute a gridded PSF model for one SCA via
`webbpsf.gridded_library.CreatePSFLibrary`.
Parameters
----------
path_prefix : str or Path-like
Prefix to the output file path for the gridded PSF model
FITS file. A suffix denoting the detector name will
be appended.
filt : str
Filter name, starting with "F". For example: `"F184"`.
path : str or Path-like
Output file path for the gridded PSF model FITS file
detector : str
Computed gridded PSF model for this SCA.
Examples include: `"SCA01"` or `"SCA18"`.
Expand All @@ -84,6 +87,9 @@ def create_gridded_psf_model(
For example, WebbPSF assumes Roman pointing jitter consistent with
mission specs by default, but this can be turned off with:
``{'jitter': None, 'jitter_sigma': 0}``.
logging_level : str, optional
Set logging level by name if not `None`, otherwise inherit from
the romancal logger.
Returns
-------
Expand All @@ -103,7 +109,7 @@ def create_gridded_psf_model(
n_psfs = int(sqrt_n_psfs) ** 2

# webbpsf appends "_sca??.fits" to the requested path:
expected_output_path = f"{path}_{detector.lower()}.fits"
expected_output_path = f"{path_prefix}_{detector.lower()}.fits"

# Choose pixel boundaries for the grid of PSFs:
start_pix = 0
Expand All @@ -118,10 +124,13 @@ def create_gridded_psf_model(
model_psf_centroids = [(int(x), int(y)) for y in pixel_range for x in pixel_range]

if not os.path.exists(expected_output_path) or overwrite:
webbpsf_logging(
if logging_level is None:
# pass along logging level from __name__'s logger to WebbPSF:
level=logging.getLevelName(log.level)
)
logging_level = logging.getLevelName(log.level)

# set the WebbPSF logging level (similar to webbpsf.utils.setup_logging):
conf.logging_level = logging_level
restart_logging(verbose=False)

wfi = webbpsf.roman.WFI()
wfi.filter = filt
Expand All @@ -145,7 +154,7 @@ def create_gridded_psf_model(
add_distortion=False,
crop_psf=False,
save=True,
filename=path,
filename=path_prefix,
overwrite=overwrite,
verbose=False,
)
Expand Down Expand Up @@ -189,7 +198,7 @@ def fit_psf_to_image_model(
y_init=None,
progress_bar=False,
error_lower_limit=None,
fit_shape=(16, 16),
fit_shape=(15, 15),
exclude_out_of_bounds=True,
):
"""
Expand Down Expand Up @@ -250,6 +259,8 @@ def fit_psf_to_image_model(
Returns
-------
results_table : `astropy.table.QTable`
PSF photometry results.
photometry : instance of class ``photutils_cls``
PSF photometry instance with configuration settings and results.
Expand Down Expand Up @@ -308,14 +319,17 @@ def fit_psf_to_image_model(

if dq is None:
if image_model is not None:
mask = dq_to_boolean_mask(image_model.dq)
mask = dq_to_boolean_mask(image_model)
else:
mask = None
else:
mask = dq_to_boolean_mask(dq)

if data is None and image_model is not None:
data = image_model.data

if error is None and image_model is not None:
error = image_model.error
error = image_model.err

if error_lower_limit is not None:
# option to enforce a lower limit on the flux uncertainties
Expand All @@ -333,20 +347,21 @@ def fit_psf_to_image_model(
guesses = guesses[init_centroid_in_range]

# fit the model PSF to the data:
photometry(data=data, error=error, init_params=guesses, mask=mask)
results_table = photometry(data=data, error=error, init_params=guesses, mask=mask)

# results are stored on the PSFPhotometry instance:
return photometry
return results_table, photometry


def dq_to_boolean_mask(image_model, ignore_flags=0, flag_map_name="roman_dq"):
def dq_to_boolean_mask(image_model_or_dq, ignore_flags=0, flag_map_name="ROMAN_DQ"):
"""
Convert a DQ bitmask to a boolean mask. Useful for photutils methods.
Parameters
----------
image_model : `roman_datamodels.datamodels.ImageModel`
ImageModel containing the DQ bitmask to convert to a boolean mask
image_model_or_dq : `roman_datamodels.datamodels.ImageModel` or `numpy.ndarray`
ImageModel containing the DQ bitmask to convert to a boolean mask,
or the DQ bitmask itself.
ignore_flags : int, str, list, None (default = 0)
See docs for `astropy.nddata.bitmask.extend_bit_flag_map`
flag_map_name : str
Expand All @@ -357,11 +372,16 @@ def dq_to_boolean_mask(image_model, ignore_flags=0, flag_map_name="roman_dq"):
mask : `numpy.ndarray`
Boolean mask
"""

if isinstance(image_model_or_dq, ImageModel):
dq = image_model_or_dq.dq

# add the Roman DQ flags to the astropy bitmask registry:
dq_flag_map = dict(roman_dq_flag_map)
dq_flag_map.pop("GOOD")

bitmask.extend_bit_flag_map(flag_map_name, **dq_flag_map)

# convert the bitmask to a boolean mask:
mask = bitmask.bitfield_to_boolean_mask(
image_model.dq, ignore_flags=ignore_flags, flag_map_name=flag_map_name
)
mask = bitmask.bitfield_to_boolean_mask(dq, ignore_flags=ignore_flags)
return mask
156 changes: 156 additions & 0 deletions romancal/lib/tests/test_psf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""
Unit tests for the Roman source detection step code
"""

import os
import tempfile

import numpy as np
import pytest
from astropy import units as u
from astropy.nddata import overlap_slices
from photutils.psf import PSFPhotometry
from roman_datamodels import maker_utils as testutil
from roman_datamodels.datamodels import ImageModel

from romancal.lib.psf import create_gridded_psf_model, fit_psf_to_image_model

n_sources = 10
image_model_shape = (100, 100)
rng = np.random.default_rng(0)


@pytest.fixture
def setup_inputs():
def _setup(
nrows=image_model_shape[0], ncols=image_model_shape[1], noise=1.0, seed=None
):
"""
Return ImageModel of level 2 image.
"""
shape = (nrows, ncols)
wfi_image = testutil.mk_level2_image(shape=shape)
wfi_image.data = u.Quantity(
np.ones(shape, dtype=np.float32), u.electron / u.s, dtype=np.float32
)
wfi_image.meta.filename = "filename"

# add noise to data
if noise is not None:
rng = np.random.default_rng(seed or 19)
wfi_image.data = u.Quantity(
rng.normal(scale=noise, size=shape), u.electron / u.s, dtype=np.float32
)
wfi_image.err = noise * np.ones(shape, dtype=np.float32) * u.electron / u.s

# add dq array
wfi_image.dq = np.zeros(shape, dtype=np.uint32)

# construct ImageModel
mod = ImageModel(wfi_image)

return mod

return _setup


def add_synthetic_sources(
image_model,
psf_model,
true_x,
true_y,
true_amp,
oversample,
xname="x_0",
yname="y_0",
):
fit_models = []

# ensure truths are arrays:
true_x, true_y, true_amp = (
np.atleast_1d(truth) for truth in [true_x, true_y, true_amp]
)

for x, y, amp in zip(true_x, true_y, true_amp):
psf = psf_model.copy()
psf.parameters = [amp, x, y]
fit_models.append(psf)

synth_image = image_model.data
synth_err = image_model.err
psf_shape = np.array(psf_model.data.shape[1:]) // oversample

for fit_model in fit_models:
x0 = getattr(fit_model, xname).value
y0 = getattr(fit_model, yname).value
slc_lg, _ = overlap_slices(synth_image.shape, psf_shape, (y0, x0), mode="trim")
yy, xx = np.mgrid[slc_lg]
model_data = fit_model(xx, yy) * image_model.data.unit
model_err = np.sqrt(model_data.value) * model_data.unit
synth_image[slc_lg] += model_data
synth_err[slc_lg] = np.sqrt(synth_err[slc_lg] ** 2 + model_err**2)


@pytest.mark.parametrize(
"dx, dy, true_amp",
zip(
rng.uniform(-1, 1, n_sources),
rng.uniform(-1, 1, n_sources),
np.geomspace(10, 10_000, n_sources),
),
)
def test_psf_fit(setup_inputs, dx, dy, true_amp, seed=42):
# input parameters for PSF model:
filt = "F087"
detector = "SCA01"
oversample = 12
fov_pixels = 15

dir_path = tempfile.gettempdir()
filename_prefix = f"psf_model_{filt}"
file_path = os.path.join(dir_path, filename_prefix)

# compute gridded PSF model:
psf_model, centroids = create_gridded_psf_model(
file_path,
filt,
detector,
oversample=oversample,
fov_pixels=fov_pixels,
overwrite=False,
logging_level="ERROR",
)

# generate an ImageModel
image_model = setup_inputs(seed=seed)

# add synthetic sources to the ImageModel:

true_x = image_model_shape[0] / 2 + dx
true_y = image_model_shape[1] / 2 + dy
add_synthetic_sources(
image_model, psf_model, true_x, true_y, true_amp, oversample=oversample
)

if fov_pixels % 2 == 0:
fit_shape = (fov_pixels + 1, fov_pixels + 1)
else:
fit_shape = (fov_pixels, fov_pixels)

# fit the PSF to the ImageModel:
results_table, photometry = fit_psf_to_image_model(
image_model=image_model,
photometry_cls=PSFPhotometry,
psf_model=psf_model,
x_init=true_x,
y_init=true_y,
fit_shape=fit_shape,
)

# difference between input and output, normalized by the
# uncertainty. Has units of sigma:
delta_x = np.abs(true_x - results_table["x_fit"]) / results_table["x_err"]
delta_y = np.abs(true_x - results_table["x_fit"]) / results_table["x_err"]

assert np.all(delta_x < 3)
assert np.all(delta_y < 3)

0 comments on commit f8baf48

Please sign in to comment.