From fdcf97d46cf3296ae5b604203f60d2d3f900c139 Mon Sep 17 00:00:00 2001 From: Drew Leonard Date: Fri, 17 May 2024 14:56:00 +0100 Subject: [PATCH] Improve performance of VaryingCelestialTransform (#370) * Change params on CompoundModel to avoid looping over model creation * Take units out of evaulate loop * Put units back on for inverse as well * Enforce strict units * Yeah you can go ahead and burn all that * Strip the units somewhere else instead and also fix them * More units fiddling * generate_celestial_transform needs to take a transform now * Assorted units fixes for tests now there are units inside the transform * Make the tests pass, but at what cost * Add a changelog --------- Co-authored-by: Stuart Mumford --- changelog/370.feature.rst | 1 + dkist/wcs/models.py | 153 ++++++++++++++++++++++++++------- dkist/wcs/tests/test_models.py | 10 +-- 3 files changed, 128 insertions(+), 36 deletions(-) create mode 100644 changelog/370.feature.rst diff --git a/changelog/370.feature.rst b/changelog/370.feature.rst new file mode 100644 index 00000000..17e99862 --- /dev/null +++ b/changelog/370.feature.rst @@ -0,0 +1 @@ +Improve performance of ``VaryingCelestialTransform`` classes by not creating a new transform for every set of parameters but instead update the parameters on a single model. diff --git a/dkist/wcs/models.py b/dkist/wcs/models.py index 27a2957e..bcdb115c 100755 --- a/dkist/wcs/models.py +++ b/dkist/wcs/models.py @@ -50,13 +50,13 @@ def generate_celestial_transform( ---------- crpix The reference pixel (a length two array). - crval - The world coordinate at the reference pixel (a length two array). cdelt The sample interval along the pixel axis pc The rotation matrix for the affine transform. If specifying parameters with units this should have celestial (``u.deg``) units. + crval + The world coordinate at the reference pixel (a length two array). lon_pole The longitude of the celestial pole, defaults to 180 degrees. projection @@ -97,6 +97,53 @@ def generate_celestial_transform( return shift | rot | scale | projection | skyrot +def update_celestial_transform_parameters( + transform, + crpix: Iterable[float] | u.Quantity, + cdelt: Iterable[float] | u.Quantity, + pc: ArrayLike | u.Quantity, + crval: Iterable[float] | u.Quantity, + lon_pole: float | u.Quantity, +) -> CompoundModel: + """ + Update an existing transform with new parameter values. + + .. warning:: + This assumes that the type (quantity vs not quantity) and units of the + new parameters are valid for the model. + + Parameters + ---------- + crpix + The reference pixel (a length two array). + cdelt + The sample interval along the pixel axis + pc + The rotation matrix for the affine transform. If specifying parameters + with units this should have celestial (``u.deg``) units. + crval + The world coordinate at the reference pixel (a length two array). + lon_pole + The longitude of the celestial pole, defaults to 180 degrees. + """ + new_params = [ + -crpix[0], + -crpix[1], + pc, + transform[2].translation.value, + cdelt[0], + cdelt[1], + crval[0], + crval[1], + lon_pole + ] + + for name, val in zip(transform.param_names, new_params): + setattr(transform, name, val) + + return transform + + class BaseVaryingCelestialTransform(Model, ABC): """ Shared components between the forward and reverse varying celestial transforms. @@ -106,6 +153,7 @@ class BaseVaryingCelestialTransform(Model, ABC): standard_broadcasting = False _separable = False _input_units_allow_dimensionless = True + _input_units_strict = True _is_inverse = False crpix = Parameter() @@ -169,6 +217,15 @@ def __init__(self, *args, crval_table=None, pc_table=None, projection=m.Pix2Sky_ if len(self.table_shape) != self.n_inputs-2: raise ValueError(f"This model can only be constructed with a {self.n_inputs-2}-dimensional lookup table.") + self._transform = generate_celestial_transform( + crpix=[0, 0], + cdelt=[1, 1], + pc=np.identity(2), + crval=[0, 0], + lon_pole=180, + projection=projection, + ) + def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None): """ Generate a spatial model based on an index for the pc and crval tables. @@ -176,7 +233,7 @@ def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None): Parameters ---------- zind : int - The index to the lookup table.Q + The index to the lookup table. **kwargs The keyword arguments are optional and if not specified will be read @@ -187,27 +244,60 @@ def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None): `astropy.modeling.CompoundModel` """ - # If we are being called from inside evaluate we can skip the lookup - crpix = crpix if crpix is not None else self.crpix - cdelt = cdelt if cdelt is not None else self.cdelt - lon_pole = lon_pole if lon_pole is not None else self.lon_pole - + # If we are out of bounds of the lookup table return a constant model fill_val = np.nan - if isinstance(crpix, u.Quantity): - fill_val = np.nan * u.deg if (np.array(ind) > np.array(self.table_shape) - 1).any() or (np.array(ind) < 0).any(): return m.Const1D(fill_val) & m.Const1D(fill_val) - return generate_celestial_transform( + # The self._transform is always unitless and always in degrees. + # So we need to strip down the parameters to be in the correct units + + # If we are being called from inside evaluate we can skip the lookup + # but we have to handle both dimensionless and unitful parameters + if crpix is None: + if self.crpix.unit is not None: + crpix = u.Quantity(self.crpix) + else: + crpix = self.crpix.value + if cdelt is None: + if self.cdelt.unit is not None: + cdelt = u.Quantity(self.cdelt) + else: + cdelt = self.cdelt.value + if lon_pole is None: + if self.lon_pole.unit is not None: + lon_pole = u.Quantity(self.lon_pole) + else: + lon_pole = self.lon_pole.value + + if isinstance(self.pc_table, u.Quantity): + pc = self.pc_table[ind].to_value(u.pix) + else: + pc = self.pc_table[ind] + + if isinstance(self.crval_table, u.Quantity): + crval = self.crval_table[ind].to_value(u.deg) + else: + crval = self.crval_table[ind] + + if isinstance(crpix, u.Quantity): + crpix = crpix.to_value(u.pix) + + if isinstance(cdelt, u.Quantity): + cdelt = cdelt.to_value(u.deg / u.pix) + + if isinstance(lon_pole, u.Quantity): + lon_pole = lon_pole.to_value(u.deg) + + return update_celestial_transform_parameters( + self._transform, crpix=crpix, cdelt=cdelt, - pc=self.pc_table[ind], - crval=self.crval_table[ind], + pc=pc, + crval=crval, lon_pole=lon_pole, - projection=self.projection, ) - def _map_transform(self, *arrays, crpix, cdelt, lon_pole, inverse=False): # We need to broadcast the arrays together so they are all the same shape barrays = np.broadcast_arrays(*arrays, subok=True) @@ -216,13 +306,13 @@ def _map_transform(self, *arrays, crpix, cdelt, lon_pole, inverse=False): for barray in barrays[2:]: inds.append(self.sanitize_index(barray)) - # Generate output arrays (ignore units for simplicity) if isinstance(barrays[0], u.Quantity): - x_out = np.empty_like(barrays[0].value) - y_out = np.empty_like(barrays[1].value) - else: - x_out = np.empty_like(barrays[0]) - y_out = np.empty_like(barrays[1]) + # Because we have set input_units_strict to True we can assume that + # all inputs have the correct units for the transform + arrays = [arr.value for arr in barrays] + + x_out = np.empty_like(arrays[0]) + y_out = np.empty_like(arrays[1]) # We now loop over every unique value of z and compute the transform. # This means we make the minimum number of calls possible to the transform. @@ -238,19 +328,20 @@ def _map_transform(self, *arrays, crpix, cdelt, lon_pole, inverse=False): else: mask = masks[0] if inverse: - xx, yy = sct.inverse(barrays[0][mask], barrays[1][mask]) + xx, yy = sct.inverse(arrays[0][mask], arrays[1][mask]) else: - xx, yy = sct(barrays[0][mask], barrays[1][mask]) + xx, yy = sct(arrays[0][mask], arrays[1][mask]) - if isinstance(xx, u.Quantity): - x_out[mask], y_out[mask] = xx.value, yy.value - else: - x_out[mask], y_out[mask] = xx, yy + x_out[mask], y_out[mask] = xx, yy - # Put the units back - if isinstance(xx, u.Quantity): - x_out = x_out << xx.unit - y_out = y_out << yy.unit + # Put the units back if we started with some + if isinstance(barrays[0], u.Quantity): + if self._is_inverse: + x_out = x_out << u.pix + y_out = y_out << u.pix + else: + x_out = x_out << u.deg + y_out = y_out << u.deg return x_out, y_out diff --git a/dkist/wcs/tests/test_models.py b/dkist/wcs/tests/test_models.py index 5677e2b9..f3e4a52b 100755 --- a/dkist/wcs/tests/test_models.py +++ b/dkist/wcs/tests/test_models.py @@ -64,7 +64,7 @@ def test_varying_transform_no_lon_pole_unit(): ) trans5 = vct.transform_at_index(5) assert isinstance(trans5, CompoundModel) - assert u.allclose(trans5.right.lon_pole, 180 * u.deg) + assert u.allclose(trans5.right.lon_pole, 180) def test_varying_transform_pc(): @@ -84,7 +84,7 @@ def test_varying_transform_pc(): # Verify that we have the 5th matrix in the series affine = next(filter(lambda sm: isinstance(sm, m.AffineTransformation2D), trans5.traverse_postorder())) assert isinstance(affine, m.AffineTransformation2D) - assert u.allclose(affine.matrix, varying_matrix_lt[5]) + assert u.allclose(affine.matrix, varying_matrix_lt[5].value) # x.shape=(1,), y.shape=(1,), z.shape=(1,) pixel = (0*u.pix, 0*u.pix, 5*u.pix) world = vct(*pixel) @@ -163,8 +163,8 @@ def test_varying_transform_crval(): # Verify that we have the 2nd crval pair in the series crval1 = trans2.right.lon crval2 = trans2.right.lat - assert u.allclose(crval1, crval_table[2][0]) - assert u.allclose(crval2, crval_table[2][1]) + assert u.allclose(crval1, crval_table[2][0].to_value(u.deg)) + assert u.allclose(crval2, crval_table[2][1].to_value(u.deg)) pixel = (0*u.pix, 0*u.pix, 2*u.pix) world = vct(*pixel) @@ -274,7 +274,7 @@ def test_varying_transform_4d_pc_unitless(): ((np.arange(10) * u.pix, np.arange(10) * u.pix, np.arange(5)[..., None] * u.pix, - np.arange(3)[..., None, None]), (3, 5, 10)), + np.arange(3)[..., None, None] * u.pix), (3, 5, 10)), ]) def test_varying_transform_4d_pc_shapes(pixel, lon_shape): varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)] * u.pix