diff --git a/changelog/285.bugfix.rst b/changelog/285.bugfix.rst new file mode 100644 index 00000000..8935d6a1 --- /dev/null +++ b/changelog/285.bugfix.rst @@ -0,0 +1 @@ +Fixed inverse transform in `.VaryingCelestialTransformSlit2D._map_transform`. Which fixes a bug in VISP WCSes. diff --git a/dkist/wcs/models.py b/dkist/wcs/models.py index 3d28935e..cdeb437f 100755 --- a/dkist/wcs/models.py +++ b/dkist/wcs/models.py @@ -205,7 +205,7 @@ def _map_transform(self, x, y, z, crpix, cdelt, lon_pole, inverse=False): # We need to broadcast the arrays together so they are all the same shape bx, by, bz = np.broadcast_arrays(x, y, z, subok=True) # Convert the z coordinate into an index to the lookup tables - ind = self.sanitize_index(bz) + zind = self.sanitize_index(bz) # Generate output arrays (ignore units for simplicity) if isinstance(bx, u.Quantity): @@ -217,12 +217,13 @@ def _map_transform(self, x, y, z, crpix, cdelt, lon_pole, inverse=False): # We now loop over every unique value of z and compute the transform. # This means we make the minimum number of calls possible to the transform. - for zind in np.unique(ind): + z_range = np.unique(zind) + for zzind in z_range: # Scalar parameters are reshaped to be length one arrays by modeling - sct = self.transform_at_index(zind, crpix[0], cdelt[0], lon_pole[0]) + sct = self.transform_at_index(zzind, crpix[0], cdelt[0], lon_pole[0]) # Call this transform for all values of x, y where z == zind - mask = ind == zind + mask = zind == zzind if inverse: xx, yy = sct.inverse(bx[mask], by[mask]) else: @@ -323,8 +324,10 @@ def _map_transform(self, x, y, z, q, crpix, cdelt, lon_pole, inverse=False): # We now loop over every unique value of z and compute the transform. # This means we make the minimum number of calls possible to the transform. - for qq in np.unique(qind): - for zz in np.unique(zind): + z_range = np.unique(zind) # raster + q_range = np.unique(qind) # other: maps or meas + for zz in z_range: + for qq in q_range: # Scalar parameters are reshaped to be length one arrays by modeling sct = self.transform_at_index((zz, qq), crpix[0], cdelt[0], lon_pole[0]) @@ -416,15 +419,17 @@ def _map_transform(self, x, y, m, z, q, crpix, cdelt, lon_pole, inverse=False): # We now loop over every unique value of z and compute the transform. # This means we make the minimum number of calls possible to the transform. - for mm in np.unique(mind): - for qq in np.unique(qind): - for zz in np.unique(zind): + m_range = np.unique(mind) # raster + q_range = np.unique(qind) # maps + z_range = np.unique(zind) # meas + for mm in m_range: + for zz in z_range: + for qq in q_range: # Scalar parameters are reshaped to be length one arrays by modeling sct = self.transform_at_index((mm, zz, qq), crpix[0], cdelt[0], lon_pole[0]) # Call this transform for all values of x, y where z == zind q == qind and m == mind - mask = np.logical_and(mind == mm, zind == zz, qind == qq) - # TODO: Figure out what to do here... + mask = np.logical_and(np.logical_and(mind == mm, zind == zz), qind == qq) if inverse: xx, yy = sct.inverse(bx[mask], by[mask]) else: @@ -476,7 +481,6 @@ def inverse(self): return ivct -# TODO: Test class InverseVaryingCelestialTransform3D(BaseVaryingCelestialTransform3D): n_inputs = 5 n_outputs = 2 @@ -497,6 +501,7 @@ def evaluate(self, lon, lat, m, z, q, crpix, cdelt, lon_pole, **kwargs): class BaseVaryingCelestialTransformSlit(BaseVaryingCelestialTransform, ABC): def _map_transform(self, x, y, crpix, cdelt, lon_pole, inverse=False): + # x: along slit, y: raster position, both in pixels # We need to broadcast the arrays together so they are all the same shape bx, by = np.broadcast_arrays(x, y, subok=True) # Convert the z coordinate into an index to the lookup tables @@ -512,14 +517,15 @@ def _map_transform(self, x, y, crpix, cdelt, lon_pole, inverse=False): # We now loop over every unique value of y and compute the transform. # This means we make the minimum number of calls possible to the transform. - for yyind in np.unique(yind): + y_range = np.unique(yind) + for yyind in y_range: # Scalar parameters are reshaped to be length one arrays by modeling sct = self.transform_at_index(yyind, crpix[0], cdelt[0], lon_pole[0]) # Call this transform for all values of x, y where y == yind mask = yind == yyind if inverse: - # Note this isn't use as the inverse of this model uses the + # Note this isn't used as the inverse of this model uses the # standard 2D inverse. xx, yy = sct.inverse(bx[mask], by[mask]) # pragma: no cover else: @@ -636,10 +642,12 @@ def _map_transform(self, x, y, z, crpix, cdelt, lon_pole, inverse=False): # We now loop over every unique value of y and z and compute the transform. # This means we make the minimum number of calls possible to the transform. - for yyind in np.unique(yind): - for zzind in np.unique(zind): + y_range = np.unique(yind) + z_range = np.unique(zind) + for yyind in y_range: + for zzind in z_range: # Scalar parameters are reshaped to be length one arrays by modeling - sct = self.transform_at_index((zzind, yyind), crpix[0], cdelt[0], lon_pole[0]) + sct = self.transform_at_index((yyind, zzind), crpix[0], cdelt[0], lon_pole[0]) # Call this transform for all values of x, y where z == zind mask = np.logical_and(zind == zzind, yind == yyind) @@ -689,7 +697,6 @@ def __init__(self, *args, **kwargs): self.inputs = ("along_slit", "meas_num", "raster", "repeat") self.outputs = ("lon", "lat") - # TODO: Will this work for 2D if num_meas == 1? if len(self.table_shape) != 3: raise ValueError("This model can only be constructed with a three dimensional lookup table.") @@ -730,19 +737,20 @@ def _map_transform(self, x, m, y, z, crpix, cdelt, lon_pole, inverse=False): # We now loop over every unique value of y and z and compute the transform. # This means we make the minimum number of calls possible to the transform. - # TODO: why this particular order? - for mmind in np.unique(mind): - for yyind in np.unique(yind): - for zzind in np.unique(zind): + m_range = np.unique(mind) + y_range = np.unique(yind) + z_range = np.unique(zind) + for mmind in m_range: + for yyind in y_range: + for zzind in z_range: # Scalar parameters are reshaped to be length one arrays by modeling - sct = self.transform_at_index((mmind, zzind, yyind), crpix[0], cdelt[0], lon_pole[0]) + sct = self.transform_at_index((mmind, yyind, zzind), crpix[0], cdelt[0], lon_pole[0]) # Call this transform for all values of x, y where m == mind, y == yind and z == zind - mask = np.logical_and(mmind == mind, zind == zzind, yind == yyind) + mask = np.logical_and(np.logical_and(mmind == mind, zind == zzind), yind == yyind) if inverse: - # TODO: figure out what to do here... - # Note this isn't use as the inverse of this model uses the - # standard 2D inverse. + # Note this isn't used as the inverse of this model uses the + # standard 3D inverse. xx, yy = sct.inverse(bx[mask], by[mask]) # pragma: no cover else: xx, yy = sct(bx[mask], by[mask]) @@ -759,7 +767,6 @@ def _map_transform(self, x, m, y, z, crpix, cdelt, lon_pole, inverse=False): return x_out, y_out -# TODO: Test class InverseVaryingCelestialTransformSlit3D(BaseVaryingCelestialTransform3D): n_inputs = 5 n_outputs = 1 diff --git a/dkist/wcs/tests/test_models.py b/dkist/wcs/tests/test_models.py index 62b93516..8541a30a 100755 --- a/dkist/wcs/tests/test_models.py +++ b/dkist/wcs/tests/test_models.py @@ -1,3 +1,5 @@ +import copy + import numpy as np import pytest from numpy.random import default_rng @@ -53,11 +55,12 @@ def test_varying_transform_no_lon_pole_unit(): varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 10)] * u.arcsec # Without a lon_pole passed, the transform was originally setting # the unit for it to be the same as crval, which is wrong. - vct = VaryingCelestialTransform(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=(0, 0) * u.arcsec, - pc_table=varying_matrix_lt, - ) + vct = VaryingCelestialTransform( + crpix=(5, 5) * u.pix, + cdelt=(1, 1) * u.arcsec/u.pix, + crval_table=(0, 0) * u.arcsec, + pc_table=varying_matrix_lt, + ) trans5 = vct.transform_at_index(5) assert isinstance(trans5, CompoundModel) assert u.allclose(trans5.right.lon_pole, 180 * u.deg) @@ -65,11 +68,13 @@ def test_varying_transform_no_lon_pole_unit(): def test_varying_transform_pc(): varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 10)] * u.arcsec - vct = VaryingCelestialTransform(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=(0, 0) * u.arcsec, - pc_table=varying_matrix_lt, - lon_pole=180 * u.deg) + vct = VaryingCelestialTransform( + crpix=(5, 5) * u.pix, + cdelt=(1, 1) * u.arcsec/u.pix, + crval_table=(0, 0) * u.arcsec, + pc_table=varying_matrix_lt, + lon_pole=180 * u.deg, + ) trans5 = vct.transform_at_index(5) assert isinstance(trans5, CompoundModel) @@ -98,11 +103,13 @@ def test_varying_transform_pc(): def test_varying_transform_pc_shapes(pixel, lon_shape): varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 10)] * u.arcsec - vct = VaryingCelestialTransform(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=(0, 0) * u.arcsec, - pc_table=varying_matrix_lt, - lon_pole=180 * u.deg) + vct = VaryingCelestialTransform( + crpix=(5, 5) * u.pix, + cdelt=(1, 1) * u.arcsec/u.pix, + crval_table=(0, 0) * u.arcsec, + pc_table=varying_matrix_lt, + lon_pole=180 * u.deg, + ) world = vct(*pixel) assert np.array(world[0]).shape == lon_shape new_pixel = vct.inverse(*world, pixel[-1]) @@ -114,12 +121,14 @@ def test_varying_transform_pc_shapes(pixel, lon_shape): def test_varying_transform_pc_unitless(): varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 10)] - vct = VaryingCelestialTransform(crpix=(5, 5), - # without units everything is deg, so make this arcsec in deg - cdelt=((1, 1)*u.arcsec).to_value(u.deg), - crval_table=(0, 0), - pc_table=varying_matrix_lt, - lon_pole=180) + vct = VaryingCelestialTransform( + crpix=(5, 5), + # without units everything is deg, so make this arcsec in deg + cdelt=((1, 1)*u.arcsec).to_value(u.deg), + crval_table=(0, 0), + pc_table=varying_matrix_lt, + lon_pole=180, + ) trans5 = vct.transform_at_index(5) assert isinstance(trans5, CompoundModel) @@ -138,11 +147,13 @@ def test_varying_transform_pc_unitless(): def test_varying_transform_crval(): crval_table = ((0, 1), (2, 3), (4, 5)) * u.arcsec - vct = VaryingCelestialTransform(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=crval_table, - pc_table=np.identity(2) * u.arcsec, - lon_pole=180 * u.deg) + vct = VaryingCelestialTransform( + crpix=(5, 5) * u.pix, + cdelt=(1, 1) * u.arcsec/u.pix, + crval_table=crval_table, + pc_table=np.identity(2) * u.arcsec, + lon_pole=180 * u.deg, + ) trans2 = vct.transform_at_index(2) assert isinstance(trans2, CompoundModel) @@ -162,55 +173,65 @@ def test_varying_transform_crval(): def test_vct_errors(): with pytest.raises(ValueError, match="pc table should be"): - VaryingCelestialTransform(crpix=(5, 5), - cdelt=(1, 1), - crval_table=(0, 0), - pc_table=[[1, 2, 3], - [4, 4, 6], - [7, 8, 9]], - lon_pole=180) + VaryingCelestialTransform( + crpix=(5, 5), + cdelt=(1, 1), + crval_table=(0, 0), + pc_table=[[1, 2, 3], + [4, 4, 6], + [7, 8, 9]], + lon_pole=180, + ) with pytest.raises(ValueError, match="crval table should be"): - VaryingCelestialTransform(crpix=(5, 5), - cdelt=(1, 1), - crval_table=((0, 0, 0),), - pc_table=[[1, 2], - [4, 4]], - lon_pole=180) + VaryingCelestialTransform( + crpix=(5, 5), + cdelt=(1, 1), + crval_table=((0, 0, 0),), + pc_table=[[1, 2], + [4, 4]], + lon_pole=180, + ) with pytest.raises(ValueError, match="The shape of the pc and crval tables should match"): - VaryingCelestialTransform(crpix=(5, 5), - cdelt=(1, 1), - crval_table=((0, 0), (1, 1), (2, 2)), - pc_table=[[[1, 2], - [4, 4]], - [[1, 2], - [3, 4]]], - lon_pole=180) + VaryingCelestialTransform( + crpix=(5, 5), + cdelt=(1, 1), + crval_table=((0, 0), (1, 1), (2, 2)), + pc_table=[[[1, 2], + [4, 4]], + [[1, 2], + [3, 4]]], + lon_pole=180, + ) with pytest.raises(TypeError, match="projection keyword should"): - VaryingCelestialTransform(crpix=(5, 5), - cdelt=(1, 1), - crval_table=((0, 0), (1, 1)), - pc_table=[[[1, 2], - [4, 4]], - [[1, 2], - [3, 4]]], - lon_pole=180, - projection=ValueError) + VaryingCelestialTransform( + crpix=(5, 5), + cdelt=(1, 1), + crval_table=((0, 0), (1, 1)), + pc_table=[[[1, 2], + [4, 4]], + [[1, 2], + [3, 4]]], + lon_pole=180, + projection=ValueError, + ) def test_varying_transform_4d_pc(): varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)] * u.arcsec varying_matrix_lt = varying_matrix_lt.reshape((3, 5, 2, 2)) - vct = VaryingCelestialTransform2D(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=(0, 0) * u.arcsec, - pc_table=varying_matrix_lt, - lon_pole=180 * u.deg) + vct = VaryingCelestialTransform2D( + crpix=(5, 5) * u.pix, + cdelt=(1, 1) * u.arcsec/u.pix, + crval_table=(0, 0) * u.arcsec, + pc_table=varying_matrix_lt, + lon_pole=180 * u.deg, + ) pixel = 0*u.pix, 0*u.pix, 0*u.pix, 0*u.pix world = vct(*pixel) @@ -228,11 +249,13 @@ def test_varying_transform_4d_pc_unitless(): varying_matrix_lt = np.array([rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)]) varying_matrix_lt = varying_matrix_lt.reshape((3, 5, 2, 2)) - vct = VaryingCelestialTransform2D(crpix=(5, 5), - cdelt=(1, 1), - crval_table=(0, 0), - pc_table=varying_matrix_lt, - lon_pole=180) + vct = VaryingCelestialTransform2D( + crpix=(5, 5), + cdelt=(1, 1), + crval_table=(0, 0), + pc_table=varying_matrix_lt, + lon_pole=180, + ) pixel = 0, 0, 0, 0 world = vct(*pixel) @@ -255,11 +278,13 @@ def test_varying_transform_4d_pc_shapes(pixel, lon_shape): varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)] * u.arcsec varying_matrix_lt = varying_matrix_lt.reshape((5, 3, 2, 2)) - vct = VaryingCelestialTransform2D(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=(0, 0) * u.arcsec, - pc_table=varying_matrix_lt, - lon_pole=180 * u.deg) + vct = VaryingCelestialTransform2D( + crpix=(5, 5) * u.pix, + cdelt=(1, 1) * u.arcsec/u.pix, + crval_table=(0, 0) * u.arcsec, + pc_table=varying_matrix_lt, + lon_pole=180 * u.deg, + ) world = vct(*pixel) assert np.array(world[0]).shape == lon_shape new_pixel = vct.inverse(*world, *pixel[-2:]) @@ -273,9 +298,11 @@ def test_vct_dispatch(): varying_matrix_lt = varying_matrix_lt.reshape((2, 2, 2, 2, 2, 2)) crval_table = list(zip(np.arange(1, 17), np.arange(17, 33))) * u.arcsec crval_table = crval_table.reshape((2, 2, 2, 2, 2)) - kwargs = dict(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - lon_pole=180 * u.deg) + kwargs = dict( + crpix=(5, 5) * u.pix, + cdelt=(1, 1) * u.arcsec/u.pix, + lon_pole=180 * u.deg, + ) vct = varying_celestial_transform_from_tables( pc_table=varying_matrix_lt[0, 0, 0], @@ -313,9 +340,11 @@ def test_vct_shape_errors(): crval_table = list(zip(np.arange(1, 16), np.arange(16, 31))) * u.arcsec crval_table = crval_table.reshape((5, 3, 2)) - kwargs = dict(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - lon_pole=180 * u.deg) + kwargs = dict( + crpix=(5, 5) * u.pix, + cdelt=(1, 1) * u.arcsec/u.pix, + lon_pole=180 * u.deg, + ) with pytest.raises(ValueError, match="only be constructed with a one dimensional"): VaryingCelestialTransform(crval_table=crval_table, pc_table=pc_table, **kwargs) @@ -329,87 +358,84 @@ def test_vct_shape_errors(): with pytest.raises(ValueError, match="only be constructed with a two dimensional"): VaryingCelestialTransformSlit2D(crval_table=crval_table[0], pc_table=pc_table[0], **kwargs) + with pytest.raises(ValueError, match="only be constructed with a three dimensional"): + VaryingCelestialTransform3D(crval_table=crval_table[0], pc_table=pc_table[0], **kwargs) -def test_vct_slit(): - pc_table = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 10)] * u.arcsec + with pytest.raises(ValueError, match="only be constructed with a three dimensional"): + VaryingCelestialTransformSlit3D(crval_table=crval_table[0], pc_table=pc_table[0], **kwargs) - kwargs = dict(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=(0, 0) * u.arcsec, - lon_pole=180 * u.deg) - - vct_slit = VaryingCelestialTransformSlit(pc_table=pc_table, **kwargs) - pixel = (0*u.pix, 0*u.pix) - world = vct_slit(*pixel) - ipixel = vct_slit.inverse(*world, 0*u.pix) - assert u.allclose(ipixel, pixel[0], atol=1e-5*u.pix) - - world2 = vct_slit(pixel[0], 1*u.pix) - assert not u.allclose(world, world2) - - -def test_vct_slit2d(): - pc_table = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)] * u.arcsec - pc_table = pc_table.reshape((5, 3, 2, 2)) - - kwargs = dict(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=(0, 0) * u.arcsec, - lon_pole=180 * u.deg) - - vct_slit = VaryingCelestialTransformSlit2D(pc_table=pc_table, **kwargs) - pixel = (0*u.pix, 0*u.pix, 0*u.pix) - world = vct_slit(*pixel) - ipixel = vct_slit.inverse(*world, 0*u.pix, 0*u.pix) - assert u.allclose(ipixel, pixel[0], atol=1e-5*u.pix) - -def test_vct_slit3d(): - pc_table = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 30)] * u.arcsec - pc_table = pc_table.reshape((5, 2, 3, 2, 2)) - - kwargs = dict(crpix=(5, 5) * u.pix, - cdelt=(1, 1) * u.arcsec/u.pix, - crval_table=(0, 0) * u.arcsec, - lon_pole=180 * u.deg) - - vct_slit = VaryingCelestialTransformSlit3D(pc_table=pc_table, **kwargs) - pixel = (0*u.pix, 0*u.pix, 0*u.pix, 0*u.pix) - world = vct_slit(*pixel) - ipixel = vct_slit.inverse(*world, 0*u.pix, 0*u.pix, 0*u.pix) - assert u.allclose(ipixel, pixel[0], atol=1e-5*u.pix) - -def test_vct_slit_unitless(): - pc_table = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 10)] - - kwargs = dict(crpix=(5, 5), - cdelt=(1, 1), - crval_table=(0, 0), - lon_pole=180) - - vct_slit = VaryingCelestialTransformSlit(pc_table=pc_table, **kwargs) - pixel = (0, 0) - world = vct_slit(*pixel) - ipixel = vct_slit.inverse(*world, 0) - assert u.allclose(ipixel, pixel[0], atol=1e-5) - - world2 = vct_slit(pixel[0], 1) - assert not u.allclose(world, world2) - - -def test_vct_slit2d_unitless(): - pc_table = np.array([rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)]) - pc_table = pc_table.reshape((5, 3, 2, 2)) - - kwargs = dict(crpix=(5, 5), - cdelt=(1, 1), - crval_table=(0, 0), - lon_pole=180) +@pytest.mark.parametrize("num_varying_axes", [pytest.param(1, id='1D'), pytest.param(2, id='2D'), pytest.param(3, id='3D')]) +@pytest.mark.parametrize("slit", [pytest.param(True, id="spectrograph"), pytest.param(False, id="imager")]) +@pytest.mark.parametrize("has_units", [pytest.param(True, id="With Units"), pytest.param(False, id="Without Units")]) +def test_vct(has_units, slit, num_varying_axes): + if slit: + num_sensor_axes = 1 + sensor_dims = [5] + else: + num_sensor_axes = 2 + sensor_dims = [4, 8] + varying_axis_dims = np.zeros(num_varying_axes, dtype=int) + # Create increasing dimension sizes starting with a value of 2 + for i in range(num_varying_axes - 1): + varying_axis_dims[i] = i + 2 + # num raster steps + varying_axis_dims[-1] = 10 + table_length = np.prod(varying_axis_dims) + pc_table = np.array([rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, table_length)]) + pc_table = pc_table.reshape((*varying_axis_dims, 2, 2)) + crpix = (3, 3) + cdelt = (1, 1) + crval_table = np.array([0, 0]) + lon_pole = 180 + varying_axis_pts = [np.arange(npts) for npts in varying_axis_dims] + # create a set of varying axis points with a raster vector that is outside the original table + varying_axis_pts_1 = copy.deepcopy(varying_axis_pts) + varying_axis_pts_1[-1] += 1 + sensor_axis_pts = [np.arange(item) for item in sensor_dims] + atol = 1.e-5 + if has_units: + pc_table *= u.arcsec + crpix *= u.pix + cdelt *= u.arcsec / u.pix + crval_table *= u.arcsec + lon_pole *= u.deg + for i in range(num_varying_axes): + varying_axis_pts[i] *= u.pix + for i in range(num_varying_axes): + varying_axis_pts_1[i] *= u.pix + atol *= u.pix + for i in range(num_sensor_axes): + sensor_axis_pts[i] *= u.pix + grid = np.meshgrid(*sensor_axis_pts, *varying_axis_pts, indexing='ij') + grid2 = np.meshgrid(*sensor_axis_pts, *varying_axis_pts_1, indexing='ij') + # the portion of the grid due to the varying axes coordinates + varying_axes_grid = grid[num_sensor_axes:] - vct_slit = VaryingCelestialTransformSlit2D(pc_table=pc_table, **kwargs) - pixel = (0, 0, 0) - world = vct_slit(*pixel) - ipixel = vct_slit.inverse(*world, 0, 0) - assert u.allclose(ipixel, pixel[0], atol=1e-5) + vct = varying_celestial_transform_from_tables( + crpix=crpix, + cdelt=cdelt, + pc_table=pc_table, + crval_table=crval_table, + lon_pole=lon_pole, + slit=slit, + ) + # forward transform returns lat and long vectors + world = vct(*grid) + assert len(world) == 2 + grid_shape = grid[0].shape + assert np.all([world_item.shape == grid_shape for world_item in world]) + assert np.all([grid_item.shape == grid_shape for grid_item in grid]) + # there should be no nans in world: + assert not np.any(np.isnan(world)) + # reverse transform to get round trip + ipixel = vct.inverse(*world, *varying_axes_grid) + # round trip should be the same as what we started with + # grid[:num_sensor_axes] is the set of on-sensor coordinates + assert u.allclose(ipixel, grid[:num_sensor_axes], atol=atol) + + # grid2 has coordinates outside the lut boundaries and should have nans + world2 = vct(*grid2) + assert np.any(np.isnan([item for item in world2])) def _evaluate_ravel(array_shape, inputs, order="C"): @@ -504,7 +530,11 @@ def test_raveled_tabular1d(ndim, has_units, input_type): values *= units lut_values = values tabular = Tabular1D( - values, lut_values, bounds_error=False, fill_value=np.nan, method="linear" + values, + lut_values, + bounds_error=False, + fill_value=np.nan, + method="linear", ) raveled_tab = ravel | tabular if input_type == "array":