Skip to content

Commit

Permalink
Bring back inverse classes
Browse files Browse the repository at this point in the history
  • Loading branch information
SolarDrew committed Mar 22, 2024
1 parent b140961 commit 4eb7cda
Showing 1 changed file with 94 additions and 24 deletions.
118 changes: 94 additions & 24 deletions dkist/wcs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,20 @@ def sanitize_index(ind):
ind = ind.value
return np.array(np.round(ind), dtype=int)

def __init__(self, *args, crval_table=None, pc_table=None, projection=m.Pix2Sky_TAN(), inverse=False, **kwargs):
def __init__(self, *args, crval_table=None, pc_table=None, projection=m.Pix2Sky_TAN(), **kwargs):
super().__init__(*args, **kwargs)
(
self.table_shape,
self.pc_table,
self.crval_table,
) = self._validate_table_shapes(np.asanyarray(pc_table), np.asanyarray(crval_table))
self.is_inverse = inverse
self._is_inverse = False

if not isinstance(projection, m.Pix2SkyProjection):
raise TypeError("The projection keyword should be a Pix2SkyProjection model class.")
self.projection = projection

if self.is_inverse:
if self._is_inverse:
self.inputs = ("lon", "lat", "z", "q", "m")[:self.n_inputs]
self.outputs = ("x", "y")

Check warning on line 161 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L160-L161

Added lines #L160 - L161 were not covered by tests
else:
Expand All @@ -166,8 +166,6 @@ def __init__(self, *args, crval_table=None, pc_table=None, projection=m.Pix2Sky_
if len(self.table_shape) != self.n_inputs-2:
raise ValueError(f"This model can only be constructed with a {self.n_inputs-2}-dimensional lookup table.")


# @cache
def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None):
"""
Generate a spatial model based on an index for the pc and crval tables.
Expand Down Expand Up @@ -197,7 +195,6 @@ def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None):
if (np.array(ind) > np.array(self.table_shape) - 1).any() or (np.array(ind) < 0).any():
return m.Const1D(fill_val) & m.Const1D(fill_val)

print(f"Generating transform at {ind=}")
sct = generate_celestial_transform(
crpix=crpix,
cdelt=cdelt,
Expand Down Expand Up @@ -263,12 +260,12 @@ def evaluate(self, *inputs):
kwargs = inputs[self.n_inputs:]
keys = ["crpix", "cdelt", "lon_pole"]
kwargs = dict(zip(keys, kwargs))
return self._map_transform(*arrays, inverse=self.is_inverse, **kwargs)
return self._map_transform(*arrays, inverse=self._is_inverse, **kwargs)

@property
def input_units(self):
# NB: x and y are normally on the detector and z is typically the number of raster steps
if self.is_inverse:
if self._is_inverse:
dims = ["z", "q", "m"]
{d: u.pix for d in dims[:self.n_inputs]}
units = {"lon": u.deg, "lat": u.deg}
Expand All @@ -279,39 +276,113 @@ def input_units(self):
dims = ["x", "y", "z", "q", "m"]
return {d: u.pix for d in dims[:self.n_inputs]}


class VaryingCelestialTransform(BaseVaryingCelestialTransform):
"""
A celestial transform which can vary its pointing and rotation with time.
This model stores a lookup table for the reference pixel ``crval_table``
and the rotation matrix ``pc_table`` which are indexed with a third pixel
index (z).
The other parameters (``crpix``, ``cdelt``, and ``lon_pole``) are fixed.
"""
n_inputs = 3

@property
def inverse(self):
ivct = self.__class__(
ivct = InverseVaryingCelestialTransform(
crpix=self.crpix,
cdelt=self.cdelt,
lon_pole=self.lon_pole,
pc_table=self.pc_table,
crval_table=self.crval_table,
projection=self.projection,
inverse=True
)
return ivct


class VaryingCelestialTransform(BaseVaryingCelestialTransform):
"""
A celestial transform which can vary it's pointing and rotation with time.
class VaryingCelestialTransform2D(BaseVaryingCelestialTransform):
n_inputs = 4

This model stores a lookup table for the reference pixel ``crval_table``
and the rotation matrix ``pc_table`` which are indexed with a third pixel
index (z).
@property
def inverse(self):
ivct = InverseVaryingCelestialTransform2D(
crpix=self.crpix,
cdelt=self.cdelt,
lon_pole=self.lon_pole,
pc_table=self.pc_table,
crval_table=self.crval_table,
projection=self.projection,
)
return ivct

The other parameters (``crpix``, ``cdelt``, and ``lon_pole``) are fixed.
"""

class VaryingCelestialTransform3D(BaseVaryingCelestialTransform):
n_inputs = 5

@property
def inverse(self):
ivct = InverseVaryingCelestialTransform3D(
crpix=self.crpix,
cdelt=self.cdelt,
lon_pole=self.lon_pole,
pc_table=self.pc_table,
crval_table=self.crval_table,
projection=self.projection,
)
return ivct


class InverseVaryingCelestialTransform(BaseVaryingCelestialTransform):
n_inputs = 3
_is_inverse = True

@property
def inverse(self):
vct = VaryingCelestialTransform(

Check warning on line 343 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L343

Added line #L343 was not covered by tests
crpix=self.crpix,
cdelt=self.cdelt,
lon_pole=self.lon_pole,
pc_table=self.pc_table,
crval_table=self.crval_table,
projection=self.projection,
)
return vct

Check warning on line 351 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L351

Added line #L351 was not covered by tests


class VaryingCelestialTransform2D(BaseVaryingCelestialTransform):
class InverseVaryingCelestialTransform2D(BaseVaryingCelestialTransform):
n_inputs = 4
_is_inverse = True

@property
def inverse(self):
vct = VaryingCelestialTransform2D(

Check warning on line 360 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L360

Added line #L360 was not covered by tests
crpix=self.crpix,
cdelt=self.cdelt,
lon_pole=self.lon_pole,
pc_table=self.pc_table,
crval_table=self.crval_table,
projection=self.projection,
)
return vct

Check warning on line 368 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L368

Added line #L368 was not covered by tests


class VaryingCelestialTransform3D(BaseVaryingCelestialTransform):
class InverseVaryingCelestialTransform3D(BaseVaryingCelestialTransform):
n_inputs = 5
_is_inverse = True

@property
def inverse(self):
vct = VaryingCelestialTransform3D(

Check warning on line 377 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L377

Added line #L377 was not covered by tests
crpix=self.crpix,
cdelt=self.cdelt,
lon_pole=self.lon_pole,
pc_table=self.pc_table,
crval_table=self.crval_table,
projection=self.projection,
)
return vct

Check warning on line 385 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L385

Added line #L385 was not covered by tests


class CoupledCompoundModel(CompoundModel):
Expand Down Expand Up @@ -494,9 +565,9 @@ def __repr__(self):
(1, False): VaryingCelestialTransform,
(2, False): VaryingCelestialTransform2D,
(3, False): VaryingCelestialTransform3D,
(1, True): VaryingCelestialTransform,
(2, True): VaryingCelestialTransform2D,
(3, True): VaryingCelestialTransform3D,
(1, True): InverseVaryingCelestialTransform,
(2, True): InverseVaryingCelestialTransform2D,
(3, True): InverseVaryingCelestialTransform3D,
}

def varying_celestial_transform_from_tables(
Expand Down Expand Up @@ -529,7 +600,6 @@ def varying_celestial_transform_from_tables(
pc_table=pc_table,
lon_pole=lon_pole,
projection=projection,
inverse=inverse
)

# For slit models we duplicate one of the spatial pixel inputs to also be
Expand Down

0 comments on commit 4eb7cda

Please sign in to comment.