Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Roman Prism and Grism #416

Merged
merged 15 commits into from
Jul 15, 2021
Merged
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ install:

before_script:
# Get WebbPSF data files (just a subset of the full 250 MB!) and set up environment variable
- wget https://stsci.box.com/shared/static/qcptcokkbx7fgi3c00w2732yezkxzb99.gz -O /tmp/minimal-webbpsf-data.tar.gz
- wget https://stsci.box.com/shared/static/ci3vkozwgyj82f1qle986k1hmeoggzzh.gz -O /tmp/minimal-webbpsf-data.tar.gz
- tar -xzvf /tmp/minimal-webbpsf-data.tar.gz
- export WEBBPSF_PATH="${TRAVIS_BUILD_DIR}/webbpsf-data"

Expand Down
218 changes: 195 additions & 23 deletions webbpsf/roman.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@
import os.path
import poppy
import numpy as np
from . import webbpsf_core
from scipy.interpolate import griddata

from scipy.interpolate import griddata, RegularGridInterpolator
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RegularGridInterpolator is not needed anymore now that we've simplified the handling of the extrapolation case.

from astropy.io import fits
import astropy.units as u
import logging

from . import webbpsf_core
from .optics import _fix_zgrid_NaNs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_fix_zgrid_NaNs is not needed anymore now that we've simplified the handling of the extrapolation case.



_log = logging.getLogger('webbpsf')
import pprint

GRISM_FILTER = 'G150'
PRISM_FILTER = 'P120'

class WavelengthDependenceInterpolator(object):
"""WavelengthDependenceInterpolator can be configured with
Expand Down Expand Up @@ -158,16 +164,46 @@ def get_aberration_terms(self, wavelength):
assert len(aberration_array.shape) == 2, "computed aberration array is not 2D " \
"(inconsistent number of Zernike terms " \
"at each point?)"

field_position = tuple(np.clip(self.field_position, 4, 4092))
field_position = tuple(self.field_position)
coefficients = griddata(
np.asarray(field_points),
np.asarray(aberration_terms),
field_position,
method='linear'
)
if np.any(np.isnan(coefficients)):
raise RuntimeError("Could not get aberrations for input field point")
# FIND TWO CLOSEST INPUT GRID POINTS:
dist = []
corners = field_points[1:] # use only the corner points
for i, ip in enumerate(corners):
dist.append(np.sqrt(((ip[0] - field_position[0]) ** 2) + ((ip[1] - field_position[1]) ** 2)))
min_dist_indx = np.argsort(dist)[:2] # keep two closest points
# DEFINE LINE B/W TWO POINTS, FIND ORTHOGONAL LINE AT POINT OF INTEREST,
# AND FIND INTERSECTION OF THESE TWO LINES.
x1, y1 = corners[min_dist_indx[0]]
x2, y2 = corners[min_dist_indx[1]]
dx = x2 - x1
dy = y2 - y1
a = (dy * (field_position[1] - y1) + dx * (field_position[0] - x1)) / (dx * dx + dy * dy)
closest_interp_point = (x1 + a * dx, y1 + a * dy)
# INTERPOLATE ABERRATIONS TO CLOSEST INTERPOLATED POINT:
coefficients = griddata(
np.asarray(field_points),
np.asarray(aberration_terms),
closest_interp_point,
method='linear')
# IF CLOSEST INTERPOLATED POINT IS STILL OUTSIDE THE INPUT GRID,
# THEN USE NEAREST GRID POINT INSTEAD:
if np.any(np.isnan(coefficients)):
coefficients = aberration_terms[min_dist_indx[0] + 1]
_log.warn("Attempted to get aberrations at field point {} which is outside the range "
"of the reference data; approximating to nearest input grid point".format(field_position))
else:
_log.warn("Attempted to get aberrations at field point {} which is outside the range "
"of the reference data; approximating to nearest interpolated point {}".format(
field_position, closest_interp_point))
assert not np.any(np.isnan(coefficients)), "Could not compute aberration " \
"at field point {}".format(field_position)
if self._omit_piston_tip_tilt:
_log.debug("Omitting piston/tip/tilt")
coefficients[:3] = 0.0 # omit piston, tip, and tilt Zernikes
Expand Down Expand Up @@ -302,7 +338,7 @@ def __init__(self):
self._masked_pupil_path = None

# List of filters that need the masked pupil
self._masked_filters = ['F184']
self._masked_filters = ['F184', GRISM_FILTER]

# Flag to en-/disable automatic selection of the appropriate pupil_mask
self.auto_pupil = True
Expand Down Expand Up @@ -459,12 +495,6 @@ class WFI(RomanInstrument):
def __init__(self):
"""
Initiate WFI

Parameters
-----------
set_pupil_mask_on : bool or None
Set to True or False to force using or not using the cold pupil mask,
or to None for the automatic behavior.
"""
# pixel scale is from Roman-AFTA SDT report final version (p. 91)
# https://roman.ipac.caltech.edu/sims/Param_db.html
Expand All @@ -473,22 +503,54 @@ def __init__(self):
# Initialize the pupil controller
self._pupil_controller = WFIPupilController()

# Initialize the aberrations for super().__init__
self._aberrations_files = {}
self._is_custom_aberrations = False
self._current_aberrations_file = ""

super(WFI, self).__init__("WFI", pixelscale=pixelscale)

self._pupil_controller.set_base_path(self._datapath)

self.pupil_mask_list = self._pupil_controller.pupil_mask_list

# Define defualt aberration files for WFI modes
self._aberrations_files = {
'imaging': os.path.join(self._datapath, 'wim_zernikes_cycle8.csv'),
'prism': os.path.join(self._datapath, 'wim_zernikes_cycle8_prism.csv'),
'grism': os.path.join(self._datapath, 'wim_zernikes_cycle8_grism.csv'),
'custom': None,
}

# Load default detector from aberration file
self._detector_npixels = 4096
self._detectors = _load_wfi_detector_aberrations(os.path.join(self._datapath, 'wim_zernikes_cycle8.csv'))
assert len(self._detectors.keys()) > 0
self._load_detector_aberrations(self._aberrations_files[self.mode])
self.detector = 'SCA01'

self.opd_list = [
os.path.join(self._WebbPSF_basepath, 'upscaled_HST_OPD.fits'),
]
self.pupilopd = self.opd_list[-1]

def _load_detector_aberrations(self, path):
"""
Helper function that, given a path to a file containing detector aberrations, loads the Zernike values and
populates the class' dictator list with `FieldDependentAberration` detectors. This function achieves this by
calling the `webbpsf.roman._load_wfi_detector_aberrations` function.

Users should use the `override_aberrations` function to override current aberrations.

Parameters
----------
path : string
Path to file containing detector aberrations
"""
detectors = _load_wfi_detector_aberrations(path)
assert len(detectors.keys()) > 0

self._detectors = detectors
self._current_aberrations_file = path

def _validate_config(self, **kwargs):
"""Validates that the WFI is configured sensibly

Expand Down Expand Up @@ -520,21 +582,13 @@ def pupil(self, value):
def pupil_mask(self):
return self._pupil_controller.pupil_mask

@RomanInstrument.filter.setter
def filter(self, value):
value = value.upper() # force to uppercase
if value not in self.filter_list:
raise ValueError("Instrument %s doesn't have a filter called %s." % (self.name, value))
self._filter = value
self._pupil_controller.validate_pupil(self.filter)

@pupil_mask.setter
def pupil_mask(self, name):
"""
Set the pupil mask

Parameters
------------
----------
name : string
Name of setting.
Settings:
Expand All @@ -558,6 +612,124 @@ def _unmasked_pupil_path(self):
def _masked_pupil_path(self):
return self._pupil_controller._masked_pupil_path

def _get_filter_mode(self, wfi_filter):
"""
Given a filter name, return the WFI mode

Parameters
----------
wfi_filter : string
Name of WFI filter

Returns
-------
mode : string
Returns 'imaging', 'grism' or 'prism' depending on filter.

Raises
------
ValueError
If the input filter is not found in the WFI filter list
"""

wfi_filter = wfi_filter.upper()
if wfi_filter == GRISM_FILTER:
return 'grism'
elif wfi_filter == PRISM_FILTER:
return 'prism'
elif wfi_filter in self.filter_list:
return 'imaging'
else:
raise ValueError("Instrument %s doesn't have a filter called %s." % (self.name, wfi_filter))

@property
def mode(self):
"""Current WFI mode"""
return self._get_filter_mode(self.filter)

@mode.setter
def mode(self, value):
"""Mode is set by changing filters"""
raise AttributeError("WFI mode cannot be directly specified; WFI mode is set by changing filters.")

def override_aberrations(self, aberrations_path):
"""
This function loads user provided aberrations from a file and locks this instrument
to only use the provided aberrations (even if the filter or mode change).
To release the lock and load the default aberrations, use the `reset_override_aberrations` function.
To load new user provided aberrations, simply call this function with the new path.

To load custom aberrations, please provide a csv file containing the detector names,
field point positions and Zernike values. The file should contain the following column names/values
(comments in parentheses should not be included):
- sca (Detector number)
- wavelength (µm)
- field_point (filed point number/id for SCA and wavelength, starts with 1)
- local_x (mm, local detector coords)
- local_y (mm, local detector coords)
- global_x (mm, global instrument coords)
- global_y (mm, global instrument coords)
- axis_local_angle_x (XAN)
- axis_local_angle_y (YAN)
- wfe_rms_waves (nm)
- wfe_pv_waves (waves)
- Z1 (Zernike phase NOLL coefficients)
- Z2 (Zernike phase NOLL coefficients)
- Z3 (Zernike phase NOLL coefficients)
- Z4 (Zernike phase NOLL coefficients)
.
.
.

Please refer to the default aberrations files for examples. If you have the WebbPSF data installed and defined,
you can get the path to that file by running the following:
>>> from webbpsf import roman
>>> wfi = roman.WFI()
>>> print(wfi._aberrations_files["imaging"])

Warning: You should not edit the default files!
"""
self._load_detector_aberrations(aberrations_path)
self._aberrations_files['custom'] = aberrations_path
self._is_custom_aberrations = True

def reset_override_aberrations(self):
"""Release detector aberrations override and load defaults"""
aberrations_path = self._aberrations_files[self.mode]
self._load_detector_aberrations(aberrations_path)
self._aberrations_files['custom'] = None
self._is_custom_aberrations = False

@RomanInstrument.filter.setter
def filter(self, value):

# Update Filter
# -------------
value = value.upper() # force to uppercase

if value not in self.filter_list:
raise ValueError("Instrument %s doesn't have a filter called %s." % (self.name, value))

self._filter = value

# Update Aberrations
# ------------------
# Check if _aberrations_files has been initiated (not empty) and if aberrations are locked by user
if self._aberrations_files and not self._is_custom_aberrations:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the check for if the aberrations files are empty? Under what circumstances would that be the case?

Copy link
Contributor Author

@robelgeda robelgeda Mar 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the class is initialized, with super(WFI, self).__init__("WFI", pixelscale=pixelscale), the filter is set to a new value in that function. We can not fill in self._aberrations_files before the super.__init__ call because the self._datapath attribute is not yet initialized. So we have to initialize self._aberrations_files to be empty and check for when super.__init__ sets the filter, so the code does not crash.

If we define self._datapath before super.__init__, this will not be needed. But I thought this was the safest way forward.


# Identify aberrations file for new mode
mode = self._get_filter_mode(self._filter)
aberrations_file = self._aberrations_files[mode]

# If aberrations are not already loaded for the new mode,
# load and replace detectors using the new mode's aberrations file.
if not os.path.samefile(self._current_aberrations_file, aberrations_file):
self._load_detector_aberrations(aberrations_file)

# Update Pupil
# ------------
self._pupil_controller.validate_pupil(self._filter)


class CGI(RomanInstrument):
"""
Expand Down
61 changes: 61 additions & 0 deletions webbpsf/tests/test_roman.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os
import numpy as np
import pytest
from webbpsf import roman, measure_fwhm
from numpy import allclose


GRISM_FILTER = roman.GRISM_FILTER
PRISM_FILTER = roman.PRISM_FILTER
MASKED_FLAG = "FULL_MASK"
UNMASKED_FLAG = "RIM_MASK"
AUTO_FLAG = "AUTO"
Expand Down Expand Up @@ -190,8 +194,51 @@ def _test_filter_pupil(filter_name, expected_pupil):
_test_filter_pupil('F062', wfi._unmasked_pupil_path)
_test_filter_pupil('F158', wfi._unmasked_pupil_path)
_test_filter_pupil('F146', wfi._unmasked_pupil_path)
_test_filter_pupil(PRISM_FILTER, wfi._unmasked_pupil_path)

_test_filter_pupil('F184', wfi._masked_pupil_path)
_test_filter_pupil(GRISM_FILTER, wfi._masked_pupil_path)

def test_swapping_modes(wfi=None):

if wfi is None:
wfi = roman.WFI()

tests = [
# [filter, mode, pupil_file]
['F062', 'imaging', wfi._unmasked_pupil_path],
['F184', 'imaging', wfi._masked_pupil_path],
[PRISM_FILTER, 'prism', wfi._unmasked_pupil_path],
[GRISM_FILTER, 'grism', wfi._masked_pupil_path],
]

for test_filter, test_mode, test_pupil in tests:
wfi.filter = test_filter
assert wfi.filter == test_filter
assert wfi.mode == test_mode
assert wfi._current_aberrations_file == wfi._aberrations_files[test_mode]
assert wfi.pupil == test_pupil

def test_custom_aberrations():

wfi = roman.WFI()

# Use grism aberrations_file for testing
test_aberrations_file = wfi._aberrations_files['grism']

# Test override
# -------------
wfi.override_aberrations(test_aberrations_file)

for filter in wfi.filter_list:
wfi.filter = filter
assert wfi._current_aberrations_file == test_aberrations_file, "Filter change caused override to fail"

# Test Release Override
# ---------------------
wfi.reset_override_aberrations()
assert wfi._aberrations_files['custom'] is None, "Custom aberrations file not deleted on override release."
test_swapping_modes(wfi)

def test_WFI_limits_interpolation_range():
wfi = roman.WFI()
Expand Down Expand Up @@ -236,6 +283,20 @@ def test_WFI_limits_interpolation_range():
"Aberration outside wavelength range did not return closest value."
)

# Test border pixels that are outside of the ref data
# As of cycle 8 and 9, (4, 4) is the first pixel so we
# check if (0, 0) is approximated to (4, 4) via nearest point
# approximation:

det.field_position = (0, 0)
coefficients_outlier = det.get_aberration_terms(1e-6)

det.field_position = (4, 4)
coefficients_data = det.get_aberration_terms(1e-6)

assert np.allclose(coefficients_outlier, coefficients_data), "nearest point extrapolation " \
"failed for outlier field point"

def test_CGI_detector_position():
""" Test existence of the CGI detector position etc, and that you can't set it."""
cgi = roman.CGI()
Expand Down