Skip to content

Commit

Permalink
Improve performance of VaryingCelestialTransform (DKISTDC#370)
Browse files Browse the repository at this point in the history
* Change params on CompoundModel to avoid looping over model creation

* Take units out of evaulate loop

* Put units back on for inverse as well

* Enforce strict units

* Yeah you can go ahead and burn all that

* Strip the units somewhere else instead and also fix them

* More units fiddling

* generate_celestial_transform needs to take a transform now

* Assorted units fixes for tests now there are units inside the transform

* Make the tests pass, but at what cost

* Add a changelog

---------

Co-authored-by: Stuart Mumford <[email protected]>
  • Loading branch information
SolarDrew and Cadair authored May 17, 2024
1 parent 6f7ce1e commit fdcf97d
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 36 deletions.
1 change: 1 addition & 0 deletions changelog/370.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve performance of ``VaryingCelestialTransform`` classes by not creating a new transform for every set of parameters but instead update the parameters on a single model.
153 changes: 122 additions & 31 deletions dkist/wcs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def generate_celestial_transform(
----------
crpix
The reference pixel (a length two array).
crval
The world coordinate at the reference pixel (a length two array).
cdelt
The sample interval along the pixel axis
pc
The rotation matrix for the affine transform. If specifying parameters
with units this should have celestial (``u.deg``) units.
crval
The world coordinate at the reference pixel (a length two array).
lon_pole
The longitude of the celestial pole, defaults to 180 degrees.
projection
Expand Down Expand Up @@ -97,6 +97,53 @@ def generate_celestial_transform(
return shift | rot | scale | projection | skyrot


def update_celestial_transform_parameters(
transform,
crpix: Iterable[float] | u.Quantity,
cdelt: Iterable[float] | u.Quantity,
pc: ArrayLike | u.Quantity,
crval: Iterable[float] | u.Quantity,
lon_pole: float | u.Quantity,
) -> CompoundModel:
"""
Update an existing transform with new parameter values.
.. warning::
This assumes that the type (quantity vs not quantity) and units of the
new parameters are valid for the model.
Parameters
----------
crpix
The reference pixel (a length two array).
cdelt
The sample interval along the pixel axis
pc
The rotation matrix for the affine transform. If specifying parameters
with units this should have celestial (``u.deg``) units.
crval
The world coordinate at the reference pixel (a length two array).
lon_pole
The longitude of the celestial pole, defaults to 180 degrees.
"""
new_params = [
-crpix[0],
-crpix[1],
pc,
transform[2].translation.value,
cdelt[0],
cdelt[1],
crval[0],
crval[1],
lon_pole
]

for name, val in zip(transform.param_names, new_params):
setattr(transform, name, val)

return transform


class BaseVaryingCelestialTransform(Model, ABC):
"""
Shared components between the forward and reverse varying celestial transforms.
Expand All @@ -106,6 +153,7 @@ class BaseVaryingCelestialTransform(Model, ABC):
standard_broadcasting = False
_separable = False
_input_units_allow_dimensionless = True
_input_units_strict = True
_is_inverse = False

crpix = Parameter()
Expand Down Expand Up @@ -169,14 +217,23 @@ def __init__(self, *args, crval_table=None, pc_table=None, projection=m.Pix2Sky_
if len(self.table_shape) != self.n_inputs-2:
raise ValueError(f"This model can only be constructed with a {self.n_inputs-2}-dimensional lookup table.")

self._transform = generate_celestial_transform(
crpix=[0, 0],
cdelt=[1, 1],
pc=np.identity(2),
crval=[0, 0],
lon_pole=180,
projection=projection,
)

def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None):
"""
Generate a spatial model based on an index for the pc and crval tables.
Parameters
----------
zind : int
The index to the lookup table.Q
The index to the lookup table.
**kwargs
The keyword arguments are optional and if not specified will be read
Expand All @@ -187,27 +244,60 @@ def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None):
`astropy.modeling.CompoundModel`
"""
# If we are being called from inside evaluate we can skip the lookup
crpix = crpix if crpix is not None else self.crpix
cdelt = cdelt if cdelt is not None else self.cdelt
lon_pole = lon_pole if lon_pole is not None else self.lon_pole

# If we are out of bounds of the lookup table return a constant model
fill_val = np.nan
if isinstance(crpix, u.Quantity):
fill_val = np.nan * u.deg
if (np.array(ind) > np.array(self.table_shape) - 1).any() or (np.array(ind) < 0).any():
return m.Const1D(fill_val) & m.Const1D(fill_val)

return generate_celestial_transform(
# The self._transform is always unitless and always in degrees.
# So we need to strip down the parameters to be in the correct units

# If we are being called from inside evaluate we can skip the lookup
# but we have to handle both dimensionless and unitful parameters
if crpix is None:
if self.crpix.unit is not None:
crpix = u.Quantity(self.crpix)
else:
crpix = self.crpix.value
if cdelt is None:
if self.cdelt.unit is not None:
cdelt = u.Quantity(self.cdelt)
else:
cdelt = self.cdelt.value
if lon_pole is None:
if self.lon_pole.unit is not None:
lon_pole = u.Quantity(self.lon_pole)
else:
lon_pole = self.lon_pole.value

if isinstance(self.pc_table, u.Quantity):
pc = self.pc_table[ind].to_value(u.pix)
else:
pc = self.pc_table[ind]

if isinstance(self.crval_table, u.Quantity):
crval = self.crval_table[ind].to_value(u.deg)
else:
crval = self.crval_table[ind]

if isinstance(crpix, u.Quantity):
crpix = crpix.to_value(u.pix)

if isinstance(cdelt, u.Quantity):
cdelt = cdelt.to_value(u.deg / u.pix)

if isinstance(lon_pole, u.Quantity):
lon_pole = lon_pole.to_value(u.deg)

return update_celestial_transform_parameters(
self._transform,
crpix=crpix,
cdelt=cdelt,
pc=self.pc_table[ind],
crval=self.crval_table[ind],
pc=pc,
crval=crval,
lon_pole=lon_pole,
projection=self.projection,
)


def _map_transform(self, *arrays, crpix, cdelt, lon_pole, inverse=False):
# We need to broadcast the arrays together so they are all the same shape
barrays = np.broadcast_arrays(*arrays, subok=True)
Expand All @@ -216,13 +306,13 @@ def _map_transform(self, *arrays, crpix, cdelt, lon_pole, inverse=False):
for barray in barrays[2:]:
inds.append(self.sanitize_index(barray))

# Generate output arrays (ignore units for simplicity)
if isinstance(barrays[0], u.Quantity):
x_out = np.empty_like(barrays[0].value)
y_out = np.empty_like(barrays[1].value)
else:
x_out = np.empty_like(barrays[0])
y_out = np.empty_like(barrays[1])
# Because we have set input_units_strict to True we can assume that
# all inputs have the correct units for the transform
arrays = [arr.value for arr in barrays]

x_out = np.empty_like(arrays[0])
y_out = np.empty_like(arrays[1])

# We now loop over every unique value of z and compute the transform.
# This means we make the minimum number of calls possible to the transform.
Expand All @@ -238,19 +328,20 @@ def _map_transform(self, *arrays, crpix, cdelt, lon_pole, inverse=False):
else:
mask = masks[0]
if inverse:
xx, yy = sct.inverse(barrays[0][mask], barrays[1][mask])
xx, yy = sct.inverse(arrays[0][mask], arrays[1][mask])
else:
xx, yy = sct(barrays[0][mask], barrays[1][mask])
xx, yy = sct(arrays[0][mask], arrays[1][mask])

if isinstance(xx, u.Quantity):
x_out[mask], y_out[mask] = xx.value, yy.value
else:
x_out[mask], y_out[mask] = xx, yy
x_out[mask], y_out[mask] = xx, yy

# Put the units back
if isinstance(xx, u.Quantity):
x_out = x_out << xx.unit
y_out = y_out << yy.unit
# Put the units back if we started with some
if isinstance(barrays[0], u.Quantity):
if self._is_inverse:
x_out = x_out << u.pix
y_out = y_out << u.pix
else:
x_out = x_out << u.deg
y_out = y_out << u.deg

return x_out, y_out

Expand Down
10 changes: 5 additions & 5 deletions dkist/wcs/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_varying_transform_no_lon_pole_unit():
)
trans5 = vct.transform_at_index(5)
assert isinstance(trans5, CompoundModel)
assert u.allclose(trans5.right.lon_pole, 180 * u.deg)
assert u.allclose(trans5.right.lon_pole, 180)


def test_varying_transform_pc():
Expand All @@ -84,7 +84,7 @@ def test_varying_transform_pc():
# Verify that we have the 5th matrix in the series
affine = next(filter(lambda sm: isinstance(sm, m.AffineTransformation2D), trans5.traverse_postorder()))
assert isinstance(affine, m.AffineTransformation2D)
assert u.allclose(affine.matrix, varying_matrix_lt[5])
assert u.allclose(affine.matrix, varying_matrix_lt[5].value)
# x.shape=(1,), y.shape=(1,), z.shape=(1,)
pixel = (0*u.pix, 0*u.pix, 5*u.pix)
world = vct(*pixel)
Expand Down Expand Up @@ -163,8 +163,8 @@ def test_varying_transform_crval():
# Verify that we have the 2nd crval pair in the series
crval1 = trans2.right.lon
crval2 = trans2.right.lat
assert u.allclose(crval1, crval_table[2][0])
assert u.allclose(crval2, crval_table[2][1])
assert u.allclose(crval1, crval_table[2][0].to_value(u.deg))
assert u.allclose(crval2, crval_table[2][1].to_value(u.deg))

pixel = (0*u.pix, 0*u.pix, 2*u.pix)
world = vct(*pixel)
Expand Down Expand Up @@ -274,7 +274,7 @@ def test_varying_transform_4d_pc_unitless():
((np.arange(10) * u.pix,
np.arange(10) * u.pix,
np.arange(5)[..., None] * u.pix,
np.arange(3)[..., None, None]), (3, 5, 10)),
np.arange(3)[..., None, None] * u.pix), (3, 5, 10)),
])
def test_varying_transform_4d_pc_shapes(pixel, lon_shape):
varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)] * u.pix
Expand Down

0 comments on commit fdcf97d

Please sign in to comment.