diff --git a/dkist/wcs/models.py b/dkist/wcs/models.py index 3813c997..d36a5d09 100755 --- a/dkist/wcs/models.py +++ b/dkist/wcs/models.py @@ -143,20 +143,20 @@ def sanitize_index(ind): ind = ind.value return np.array(np.round(ind), dtype=int) - def __init__(self, *args, crval_table=None, pc_table=None, projection=m.Pix2Sky_TAN(), inverse=False, **kwargs): + def __init__(self, *args, crval_table=None, pc_table=None, projection=m.Pix2Sky_TAN(), **kwargs): super().__init__(*args, **kwargs) ( self.table_shape, self.pc_table, self.crval_table, ) = self._validate_table_shapes(np.asanyarray(pc_table), np.asanyarray(crval_table)) - self.is_inverse = inverse + self._is_inverse = False if not isinstance(projection, m.Pix2SkyProjection): raise TypeError("The projection keyword should be a Pix2SkyProjection model class.") self.projection = projection - if self.is_inverse: + if self._is_inverse: self.inputs = ("lon", "lat", "z", "q", "m")[:self.n_inputs] self.outputs = ("x", "y") else: @@ -166,8 +166,6 @@ def __init__(self, *args, crval_table=None, pc_table=None, projection=m.Pix2Sky_ if len(self.table_shape) != self.n_inputs-2: raise ValueError(f"This model can only be constructed with a {self.n_inputs-2}-dimensional lookup table.") - - # @cache def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None): """ Generate a spatial model based on an index for the pc and crval tables. @@ -197,7 +195,6 @@ def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None): if (np.array(ind) > np.array(self.table_shape) - 1).any() or (np.array(ind) < 0).any(): return m.Const1D(fill_val) & m.Const1D(fill_val) - print(f"Generating transform at {ind=}") sct = generate_celestial_transform( crpix=crpix, cdelt=cdelt, @@ -263,12 +260,12 @@ def evaluate(self, *inputs): kwargs = inputs[self.n_inputs:] keys = ["crpix", "cdelt", "lon_pole"] kwargs = dict(zip(keys, kwargs)) - return self._map_transform(*arrays, inverse=self.is_inverse, **kwargs) + return self._map_transform(*arrays, inverse=self._is_inverse, **kwargs) @property def input_units(self): # NB: x and y are normally on the detector and z is typically the number of raster steps - if self.is_inverse: + if self._is_inverse: dims = ["z", "q", "m"] {d: u.pix for d in dims[:self.n_inputs]} units = {"lon": u.deg, "lat": u.deg} @@ -279,39 +276,113 @@ def input_units(self): dims = ["x", "y", "z", "q", "m"] return {d: u.pix for d in dims[:self.n_inputs]} + +class VaryingCelestialTransform(BaseVaryingCelestialTransform): + """ + A celestial transform which can vary its pointing and rotation with time. + + This model stores a lookup table for the reference pixel ``crval_table`` + and the rotation matrix ``pc_table`` which are indexed with a third pixel + index (z). + + The other parameters (``crpix``, ``cdelt``, and ``lon_pole``) are fixed. + """ + n_inputs = 3 + @property def inverse(self): - ivct = self.__class__( + ivct = InverseVaryingCelestialTransform( crpix=self.crpix, cdelt=self.cdelt, lon_pole=self.lon_pole, pc_table=self.pc_table, crval_table=self.crval_table, projection=self.projection, - inverse=True ) return ivct -class VaryingCelestialTransform(BaseVaryingCelestialTransform): - """ - A celestial transform which can vary it's pointing and rotation with time. +class VaryingCelestialTransform2D(BaseVaryingCelestialTransform): + n_inputs = 4 - This model stores a lookup table for the reference pixel ``crval_table`` - and the rotation matrix ``pc_table`` which are indexed with a third pixel - index (z). + @property + def inverse(self): + ivct = InverseVaryingCelestialTransform2D( + crpix=self.crpix, + cdelt=self.cdelt, + lon_pole=self.lon_pole, + pc_table=self.pc_table, + crval_table=self.crval_table, + projection=self.projection, + ) + return ivct - The other parameters (``crpix``, ``cdelt``, and ``lon_pole``) are fixed. - """ + +class VaryingCelestialTransform3D(BaseVaryingCelestialTransform): + n_inputs = 5 + + @property + def inverse(self): + ivct = InverseVaryingCelestialTransform3D( + crpix=self.crpix, + cdelt=self.cdelt, + lon_pole=self.lon_pole, + pc_table=self.pc_table, + crval_table=self.crval_table, + projection=self.projection, + ) + return ivct + + +class InverseVaryingCelestialTransform(BaseVaryingCelestialTransform): n_inputs = 3 + _is_inverse = True + + @property + def inverse(self): + vct = VaryingCelestialTransform( + crpix=self.crpix, + cdelt=self.cdelt, + lon_pole=self.lon_pole, + pc_table=self.pc_table, + crval_table=self.crval_table, + projection=self.projection, + ) + return vct -class VaryingCelestialTransform2D(BaseVaryingCelestialTransform): +class InverseVaryingCelestialTransform2D(BaseVaryingCelestialTransform): n_inputs = 4 + _is_inverse = True + + @property + def inverse(self): + vct = VaryingCelestialTransform2D( + crpix=self.crpix, + cdelt=self.cdelt, + lon_pole=self.lon_pole, + pc_table=self.pc_table, + crval_table=self.crval_table, + projection=self.projection, + ) + return vct -class VaryingCelestialTransform3D(BaseVaryingCelestialTransform): +class InverseVaryingCelestialTransform3D(BaseVaryingCelestialTransform): n_inputs = 5 + _is_inverse = True + + @property + def inverse(self): + vct = VaryingCelestialTransform3D( + crpix=self.crpix, + cdelt=self.cdelt, + lon_pole=self.lon_pole, + pc_table=self.pc_table, + crval_table=self.crval_table, + projection=self.projection, + ) + return vct class CoupledCompoundModel(CompoundModel): @@ -494,9 +565,9 @@ def __repr__(self): (1, False): VaryingCelestialTransform, (2, False): VaryingCelestialTransform2D, (3, False): VaryingCelestialTransform3D, - (1, True): VaryingCelestialTransform, - (2, True): VaryingCelestialTransform2D, - (3, True): VaryingCelestialTransform3D, + (1, True): InverseVaryingCelestialTransform, + (2, True): InverseVaryingCelestialTransform2D, + (3, True): InverseVaryingCelestialTransform3D, } def varying_celestial_transform_from_tables( @@ -529,7 +600,6 @@ def varying_celestial_transform_from_tables( pc_table=pc_table, lon_pole=lon_pole, projection=projection, - inverse=inverse ) # For slit models we duplicate one of the spatial pixel inputs to also be