Skip to content

Commit

Permalink
Merge pull request #416 from robelgeda/roman
Browse files Browse the repository at this point in the history
Roman Prism and Grism
  • Loading branch information
mperrin authored Jul 15, 2021
2 parents 2ffcd67 + 476c6b9 commit 9673b52
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 23 deletions.
218 changes: 195 additions & 23 deletions webbpsf/roman.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@
import os.path
import poppy
import numpy as np
from . import webbpsf_core
from scipy.interpolate import griddata

from scipy.interpolate import griddata, RegularGridInterpolator
from astropy.io import fits
import astropy.units as u
import logging

from . import webbpsf_core
from .optics import _fix_zgrid_NaNs


_log = logging.getLogger('webbpsf')
import pprint

GRISM_FILTER = 'G150'
PRISM_FILTER = 'P120'

class WavelengthDependenceInterpolator(object):
"""WavelengthDependenceInterpolator can be configured with
Expand Down Expand Up @@ -158,16 +164,46 @@ def get_aberration_terms(self, wavelength):
assert len(aberration_array.shape) == 2, "computed aberration array is not 2D " \
"(inconsistent number of Zernike terms " \
"at each point?)"

field_position = tuple(np.clip(self.field_position, 4, 4092))
field_position = tuple(self.field_position)
coefficients = griddata(
np.asarray(field_points),
np.asarray(aberration_terms),
field_position,
method='linear'
)
if np.any(np.isnan(coefficients)):
raise RuntimeError("Could not get aberrations for input field point")
# FIND TWO CLOSEST INPUT GRID POINTS:
dist = []
corners = field_points[1:] # use only the corner points
for i, ip in enumerate(corners):
dist.append(np.sqrt(((ip[0] - field_position[0]) ** 2) + ((ip[1] - field_position[1]) ** 2)))
min_dist_indx = np.argsort(dist)[:2] # keep two closest points
# DEFINE LINE B/W TWO POINTS, FIND ORTHOGONAL LINE AT POINT OF INTEREST,
# AND FIND INTERSECTION OF THESE TWO LINES.
x1, y1 = corners[min_dist_indx[0]]
x2, y2 = corners[min_dist_indx[1]]
dx = x2 - x1
dy = y2 - y1
a = (dy * (field_position[1] - y1) + dx * (field_position[0] - x1)) / (dx * dx + dy * dy)
closest_interp_point = (x1 + a * dx, y1 + a * dy)
# INTERPOLATE ABERRATIONS TO CLOSEST INTERPOLATED POINT:
coefficients = griddata(
np.asarray(field_points),
np.asarray(aberration_terms),
closest_interp_point,
method='linear')
# IF CLOSEST INTERPOLATED POINT IS STILL OUTSIDE THE INPUT GRID,
# THEN USE NEAREST GRID POINT INSTEAD:
if np.any(np.isnan(coefficients)):
coefficients = aberration_terms[min_dist_indx[0] + 1]
_log.warn("Attempted to get aberrations at field point {} which is outside the range "
"of the reference data; approximating to nearest input grid point".format(field_position))
else:
_log.warn("Attempted to get aberrations at field point {} which is outside the range "
"of the reference data; approximating to nearest interpolated point {}".format(
field_position, closest_interp_point))
assert not np.any(np.isnan(coefficients)), "Could not compute aberration " \
"at field point {}".format(field_position)
if self._omit_piston_tip_tilt:
_log.debug("Omitting piston/tip/tilt")
coefficients[:3] = 0.0 # omit piston, tip, and tilt Zernikes
Expand Down Expand Up @@ -302,7 +338,7 @@ def __init__(self):
self._masked_pupil_path = None

# List of filters that need the masked pupil
self._masked_filters = ['F184']
self._masked_filters = ['F184', GRISM_FILTER]

# Flag to en-/disable automatic selection of the appropriate pupil_mask
self.auto_pupil = True
Expand Down Expand Up @@ -459,12 +495,6 @@ class WFI(RomanInstrument):
def __init__(self):
"""
Initiate WFI
Parameters
-----------
set_pupil_mask_on : bool or None
Set to True or False to force using or not using the cold pupil mask,
or to None for the automatic behavior.
"""
# pixel scale is from Roman-AFTA SDT report final version (p. 91)
# https://roman.ipac.caltech.edu/sims/Param_db.html
Expand All @@ -473,22 +503,54 @@ def __init__(self):
# Initialize the pupil controller
self._pupil_controller = WFIPupilController()

# Initialize the aberrations for super().__init__
self._aberrations_files = {}
self._is_custom_aberrations = False
self._current_aberrations_file = ""

super(WFI, self).__init__("WFI", pixelscale=pixelscale)

self._pupil_controller.set_base_path(self._datapath)

self.pupil_mask_list = self._pupil_controller.pupil_mask_list

# Define defualt aberration files for WFI modes
self._aberrations_files = {
'imaging': os.path.join(self._datapath, 'wim_zernikes_cycle8.csv'),
'prism': os.path.join(self._datapath, 'wim_zernikes_cycle8_prism.csv'),
'grism': os.path.join(self._datapath, 'wim_zernikes_cycle8_grism.csv'),
'custom': None,
}

# Load default detector from aberration file
self._detector_npixels = 4096
self._detectors = _load_wfi_detector_aberrations(os.path.join(self._datapath, 'wim_zernikes_cycle8.csv'))
assert len(self._detectors.keys()) > 0
self._load_detector_aberrations(self._aberrations_files[self.mode])
self.detector = 'SCA01'

self.opd_list = [
os.path.join(self._WebbPSF_basepath, 'upscaled_HST_OPD.fits'),
]
self.pupilopd = self.opd_list[-1]

def _load_detector_aberrations(self, path):
"""
Helper function that, given a path to a file containing detector aberrations, loads the Zernike values and
populates the class' dictator list with `FieldDependentAberration` detectors. This function achieves this by
calling the `webbpsf.roman._load_wfi_detector_aberrations` function.
Users should use the `override_aberrations` function to override current aberrations.
Parameters
----------
path : string
Path to file containing detector aberrations
"""
detectors = _load_wfi_detector_aberrations(path)
assert len(detectors.keys()) > 0

self._detectors = detectors
self._current_aberrations_file = path

def _validate_config(self, **kwargs):
"""Validates that the WFI is configured sensibly
Expand Down Expand Up @@ -520,21 +582,13 @@ def pupil(self, value):
def pupil_mask(self):
return self._pupil_controller.pupil_mask

@RomanInstrument.filter.setter
def filter(self, value):
value = value.upper() # force to uppercase
if value not in self.filter_list:
raise ValueError("Instrument %s doesn't have a filter called %s." % (self.name, value))
self._filter = value
self._pupil_controller.validate_pupil(self.filter)

@pupil_mask.setter
def pupil_mask(self, name):
"""
Set the pupil mask
Parameters
------------
----------
name : string
Name of setting.
Settings:
Expand All @@ -558,6 +612,124 @@ def _unmasked_pupil_path(self):
def _masked_pupil_path(self):
return self._pupil_controller._masked_pupil_path

def _get_filter_mode(self, wfi_filter):
"""
Given a filter name, return the WFI mode
Parameters
----------
wfi_filter : string
Name of WFI filter
Returns
-------
mode : string
Returns 'imaging', 'grism' or 'prism' depending on filter.
Raises
------
ValueError
If the input filter is not found in the WFI filter list
"""

wfi_filter = wfi_filter.upper()
if wfi_filter == GRISM_FILTER:
return 'grism'
elif wfi_filter == PRISM_FILTER:
return 'prism'
elif wfi_filter in self.filter_list:
return 'imaging'
else:
raise ValueError("Instrument %s doesn't have a filter called %s." % (self.name, wfi_filter))

@property
def mode(self):
"""Current WFI mode"""
return self._get_filter_mode(self.filter)

@mode.setter
def mode(self, value):
"""Mode is set by changing filters"""
raise AttributeError("WFI mode cannot be directly specified; WFI mode is set by changing filters.")

def override_aberrations(self, aberrations_path):
"""
This function loads user provided aberrations from a file and locks this instrument
to only use the provided aberrations (even if the filter or mode change).
To release the lock and load the default aberrations, use the `reset_override_aberrations` function.
To load new user provided aberrations, simply call this function with the new path.
To load custom aberrations, please provide a csv file containing the detector names,
field point positions and Zernike values. The file should contain the following column names/values
(comments in parentheses should not be included):
- sca (Detector number)
- wavelength (µm)
- field_point (filed point number/id for SCA and wavelength, starts with 1)
- local_x (mm, local detector coords)
- local_y (mm, local detector coords)
- global_x (mm, global instrument coords)
- global_y (mm, global instrument coords)
- axis_local_angle_x (XAN)
- axis_local_angle_y (YAN)
- wfe_rms_waves (nm)
- wfe_pv_waves (waves)
- Z1 (Zernike phase NOLL coefficients)
- Z2 (Zernike phase NOLL coefficients)
- Z3 (Zernike phase NOLL coefficients)
- Z4 (Zernike phase NOLL coefficients)
.
.
.
Please refer to the default aberrations files for examples. If you have the WebbPSF data installed and defined,
you can get the path to that file by running the following:
>>> from webbpsf import roman
>>> wfi = roman.WFI()
>>> print(wfi._aberrations_files["imaging"])
Warning: You should not edit the default files!
"""
self._load_detector_aberrations(aberrations_path)
self._aberrations_files['custom'] = aberrations_path
self._is_custom_aberrations = True

def reset_override_aberrations(self):
"""Release detector aberrations override and load defaults"""
aberrations_path = self._aberrations_files[self.mode]
self._load_detector_aberrations(aberrations_path)
self._aberrations_files['custom'] = None
self._is_custom_aberrations = False

@RomanInstrument.filter.setter
def filter(self, value):

# Update Filter
# -------------
value = value.upper() # force to uppercase

if value not in self.filter_list:
raise ValueError("Instrument %s doesn't have a filter called %s." % (self.name, value))

self._filter = value

# Update Aberrations
# ------------------
# Check if _aberrations_files has been initiated (not empty) and if aberrations are locked by user
if self._aberrations_files and not self._is_custom_aberrations:

# Identify aberrations file for new mode
mode = self._get_filter_mode(self._filter)
aberrations_file = self._aberrations_files[mode]

# If aberrations are not already loaded for the new mode,
# load and replace detectors using the new mode's aberrations file.
if not os.path.samefile(self._current_aberrations_file, aberrations_file):
self._load_detector_aberrations(aberrations_file)

# Update Pupil
# ------------
self._pupil_controller.validate_pupil(self._filter)


class CGI(RomanInstrument):
"""
Expand Down
61 changes: 61 additions & 0 deletions webbpsf/tests/test_roman.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os
import numpy as np
import pytest
from webbpsf import roman, measure_fwhm
from numpy import allclose


GRISM_FILTER = roman.GRISM_FILTER
PRISM_FILTER = roman.PRISM_FILTER
MASKED_FLAG = "FULL_MASK"
UNMASKED_FLAG = "RIM_MASK"
AUTO_FLAG = "AUTO"
Expand Down Expand Up @@ -189,8 +193,51 @@ def _test_filter_pupil(filter_name, expected_pupil):
_test_filter_pupil('F062', wfi._unmasked_pupil_path)
_test_filter_pupil('F158', wfi._unmasked_pupil_path)
_test_filter_pupil('F146', wfi._unmasked_pupil_path)
_test_filter_pupil(PRISM_FILTER, wfi._unmasked_pupil_path)

_test_filter_pupil('F184', wfi._masked_pupil_path)
_test_filter_pupil(GRISM_FILTER, wfi._masked_pupil_path)

def test_swapping_modes(wfi=None):

if wfi is None:
wfi = roman.WFI()

tests = [
# [filter, mode, pupil_file]
['F062', 'imaging', wfi._unmasked_pupil_path],
['F184', 'imaging', wfi._masked_pupil_path],
[PRISM_FILTER, 'prism', wfi._unmasked_pupil_path],
[GRISM_FILTER, 'grism', wfi._masked_pupil_path],
]

for test_filter, test_mode, test_pupil in tests:
wfi.filter = test_filter
assert wfi.filter == test_filter
assert wfi.mode == test_mode
assert wfi._current_aberrations_file == wfi._aberrations_files[test_mode]
assert wfi.pupil == test_pupil

def test_custom_aberrations():

wfi = roman.WFI()

# Use grism aberrations_file for testing
test_aberrations_file = wfi._aberrations_files['grism']

# Test override
# -------------
wfi.override_aberrations(test_aberrations_file)

for filter in wfi.filter_list:
wfi.filter = filter
assert wfi._current_aberrations_file == test_aberrations_file, "Filter change caused override to fail"

# Test Release Override
# ---------------------
wfi.reset_override_aberrations()
assert wfi._aberrations_files['custom'] is None, "Custom aberrations file not deleted on override release."
test_swapping_modes(wfi)

def test_WFI_limits_interpolation_range():
wfi = roman.WFI()
Expand Down Expand Up @@ -235,6 +282,20 @@ def test_WFI_limits_interpolation_range():
"Aberration outside wavelength range did not return closest value."
)

# Test border pixels that are outside of the ref data
# As of cycle 8 and 9, (4, 4) is the first pixel so we
# check if (0, 0) is approximated to (4, 4) via nearest point
# approximation:

det.field_position = (0, 0)
coefficients_outlier = det.get_aberration_terms(1e-6)

det.field_position = (4, 4)
coefficients_data = det.get_aberration_terms(1e-6)

assert np.allclose(coefficients_outlier, coefficients_data), "nearest point extrapolation " \
"failed for outlier field point"

def test_CGI_detector_position():
""" Test existence of the CGI detector position etc, and that you can't set it."""
cgi = roman.CGI()
Expand Down

0 comments on commit 9673b52

Please sign in to comment.