Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of VaryingCelestialTransform #370

Merged
merged 13 commits into from
May 17, 2024
Merged
1 change: 1 addition & 0 deletions changelog/370.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve performance of ``VaryingCelestialTransform`` classes by not creating a new transform for every set of parameters but instead update the parameters on a single model.
153 changes: 122 additions & 31 deletions dkist/wcs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@
----------
crpix
The reference pixel (a length two array).
crval
The world coordinate at the reference pixel (a length two array).
cdelt
The sample interval along the pixel axis
pc
The rotation matrix for the affine transform. If specifying parameters
with units this should have celestial (``u.deg``) units.
crval
The world coordinate at the reference pixel (a length two array).
lon_pole
The longitude of the celestial pole, defaults to 180 degrees.
projection
Expand Down Expand Up @@ -97,6 +97,53 @@
return shift | rot | scale | projection | skyrot


def update_celestial_transform_parameters(
transform,
crpix: Iterable[float] | u.Quantity,
cdelt: Iterable[float] | u.Quantity,
pc: ArrayLike | u.Quantity,
crval: Iterable[float] | u.Quantity,
lon_pole: float | u.Quantity,
) -> CompoundModel:
"""
Update an existing transform with new parameter values.

.. warning::
This assumes that the type (quantity vs not quantity) and units of the
new parameters are valid for the model.

Parameters
----------
crpix
The reference pixel (a length two array).
cdelt
The sample interval along the pixel axis
pc
The rotation matrix for the affine transform. If specifying parameters
with units this should have celestial (``u.deg``) units.
crval
The world coordinate at the reference pixel (a length two array).
lon_pole
The longitude of the celestial pole, defaults to 180 degrees.
"""
new_params = [

Check warning on line 129 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L129

Added line #L129 was not covered by tests
-crpix[0],
-crpix[1],
pc,
transform[2].translation.value,
cdelt[0],
cdelt[1],
crval[0],
crval[1],
lon_pole
]

for name, val in zip(transform.param_names, new_params):
setattr(transform, name, val)

Check warning on line 142 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L141-L142

Added lines #L141 - L142 were not covered by tests

return transform

Check warning on line 144 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L144

Added line #L144 was not covered by tests


class BaseVaryingCelestialTransform(Model, ABC):
"""
Shared components between the forward and reverse varying celestial transforms.
Expand All @@ -106,6 +153,7 @@
standard_broadcasting = False
_separable = False
_input_units_allow_dimensionless = True
_input_units_strict = True
_is_inverse = False

crpix = Parameter()
Expand Down Expand Up @@ -169,14 +217,23 @@
if len(self.table_shape) != self.n_inputs-2:
raise ValueError(f"This model can only be constructed with a {self.n_inputs-2}-dimensional lookup table.")

self._transform = generate_celestial_transform(

Check warning on line 220 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L220

Added line #L220 was not covered by tests
crpix=[0, 0],
cdelt=[1, 1],
pc=np.identity(2),
crval=[0, 0],
lon_pole=180,
projection=projection,
)

def transform_at_index(self, ind, crpix=None, cdelt=None, lon_pole=None):
"""
Generate a spatial model based on an index for the pc and crval tables.

Parameters
----------
zind : int
The index to the lookup table.Q
The index to the lookup table.

**kwargs
The keyword arguments are optional and if not specified will be read
Expand All @@ -187,27 +244,60 @@
`astropy.modeling.CompoundModel`

"""
# If we are being called from inside evaluate we can skip the lookup
crpix = crpix if crpix is not None else self.crpix
cdelt = cdelt if cdelt is not None else self.cdelt
lon_pole = lon_pole if lon_pole is not None else self.lon_pole

# If we are out of bounds of the lookup table return a constant model
fill_val = np.nan
if isinstance(crpix, u.Quantity):
fill_val = np.nan * u.deg
if (np.array(ind) > np.array(self.table_shape) - 1).any() or (np.array(ind) < 0).any():
return m.Const1D(fill_val) & m.Const1D(fill_val)

return generate_celestial_transform(
# The self._transform is always unitless and always in degrees.
# So we need to strip down the parameters to be in the correct units

# If we are being called from inside evaluate we can skip the lookup
# but we have to handle both dimensionless and unitful parameters
if crpix is None:
if self.crpix.unit is not None:
crpix = u.Quantity(self.crpix)

Check warning on line 259 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L257-L259

Added lines #L257 - L259 were not covered by tests
else:
crpix = self.crpix.value
if cdelt is None:
if self.cdelt.unit is not None:
cdelt = u.Quantity(self.cdelt)

Check warning on line 264 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L261-L264

Added lines #L261 - L264 were not covered by tests
else:
cdelt = self.cdelt.value
if lon_pole is None:
if self.lon_pole.unit is not None:
lon_pole = u.Quantity(self.lon_pole)

Check warning on line 269 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L266-L269

Added lines #L266 - L269 were not covered by tests
else:
lon_pole = self.lon_pole.value

Check warning on line 271 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L271

Added line #L271 was not covered by tests

if isinstance(self.pc_table, u.Quantity):
pc = self.pc_table[ind].to_value(u.pix)

Check warning on line 274 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L273-L274

Added lines #L273 - L274 were not covered by tests
else:
pc = self.pc_table[ind]

Check warning on line 276 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L276

Added line #L276 was not covered by tests

if isinstance(self.crval_table, u.Quantity):
crval = self.crval_table[ind].to_value(u.deg)

Check warning on line 279 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L278-L279

Added lines #L278 - L279 were not covered by tests
else:
crval = self.crval_table[ind]

Check warning on line 281 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L281

Added line #L281 was not covered by tests

if isinstance(crpix, u.Quantity):
crpix = crpix.to_value(u.pix)

Check warning on line 284 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L283-L284

Added lines #L283 - L284 were not covered by tests

if isinstance(cdelt, u.Quantity):
cdelt = cdelt.to_value(u.deg / u.pix)

Check warning on line 287 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L286-L287

Added lines #L286 - L287 were not covered by tests

if isinstance(lon_pole, u.Quantity):
lon_pole = lon_pole.to_value(u.deg)

Check warning on line 290 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L289-L290

Added lines #L289 - L290 were not covered by tests

return update_celestial_transform_parameters(

Check warning on line 292 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L292

Added line #L292 was not covered by tests
self._transform,
crpix=crpix,
cdelt=cdelt,
pc=self.pc_table[ind],
crval=self.crval_table[ind],
pc=pc,
crval=crval,
lon_pole=lon_pole,
projection=self.projection,
)


def _map_transform(self, *arrays, crpix, cdelt, lon_pole, inverse=False):
# We need to broadcast the arrays together so they are all the same shape
barrays = np.broadcast_arrays(*arrays, subok=True)
Expand All @@ -216,13 +306,13 @@
for barray in barrays[2:]:
inds.append(self.sanitize_index(barray))

# Generate output arrays (ignore units for simplicity)
if isinstance(barrays[0], u.Quantity):
x_out = np.empty_like(barrays[0].value)
y_out = np.empty_like(barrays[1].value)
else:
x_out = np.empty_like(barrays[0])
y_out = np.empty_like(barrays[1])
# Because we have set input_units_strict to True we can assume that
# all inputs have the correct units for the transform
arrays = [arr.value for arr in barrays]

Check warning on line 312 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L312

Added line #L312 was not covered by tests

x_out = np.empty_like(arrays[0])
y_out = np.empty_like(arrays[1])

Check warning on line 315 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L314-L315

Added lines #L314 - L315 were not covered by tests

# We now loop over every unique value of z and compute the transform.
# This means we make the minimum number of calls possible to the transform.
Expand All @@ -238,19 +328,20 @@
else:
mask = masks[0]
if inverse:
xx, yy = sct.inverse(barrays[0][mask], barrays[1][mask])
xx, yy = sct.inverse(arrays[0][mask], arrays[1][mask])

Check warning on line 331 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L331

Added line #L331 was not covered by tests
else:
xx, yy = sct(barrays[0][mask], barrays[1][mask])
xx, yy = sct(arrays[0][mask], arrays[1][mask])

Check warning on line 333 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L333

Added line #L333 was not covered by tests

if isinstance(xx, u.Quantity):
x_out[mask], y_out[mask] = xx.value, yy.value
else:
x_out[mask], y_out[mask] = xx, yy
x_out[mask], y_out[mask] = xx, yy

Check warning on line 335 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L335

Added line #L335 was not covered by tests

# Put the units back
if isinstance(xx, u.Quantity):
x_out = x_out << xx.unit
y_out = y_out << yy.unit
# Put the units back if we started with some
if isinstance(barrays[0], u.Quantity):
if self._is_inverse:
x_out = x_out << u.pix
y_out = y_out << u.pix

Check warning on line 341 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L338-L341

Added lines #L338 - L341 were not covered by tests
else:
x_out = x_out << u.deg
y_out = y_out << u.deg

Check warning on line 344 in dkist/wcs/models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/models.py#L343-L344

Added lines #L343 - L344 were not covered by tests

return x_out, y_out

Expand Down
10 changes: 5 additions & 5 deletions dkist/wcs/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
)
trans5 = vct.transform_at_index(5)
assert isinstance(trans5, CompoundModel)
assert u.allclose(trans5.right.lon_pole, 180 * u.deg)
assert u.allclose(trans5.right.lon_pole, 180)

Check warning on line 67 in dkist/wcs/tests/test_models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/tests/test_models.py#L67

Added line #L67 was not covered by tests


def test_varying_transform_pc():
Expand All @@ -84,7 +84,7 @@
# Verify that we have the 5th matrix in the series
affine = next(filter(lambda sm: isinstance(sm, m.AffineTransformation2D), trans5.traverse_postorder()))
assert isinstance(affine, m.AffineTransformation2D)
assert u.allclose(affine.matrix, varying_matrix_lt[5])
assert u.allclose(affine.matrix, varying_matrix_lt[5].value)

Check warning on line 87 in dkist/wcs/tests/test_models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/tests/test_models.py#L87

Added line #L87 was not covered by tests
# x.shape=(1,), y.shape=(1,), z.shape=(1,)
pixel = (0*u.pix, 0*u.pix, 5*u.pix)
world = vct(*pixel)
Expand Down Expand Up @@ -163,8 +163,8 @@
# Verify that we have the 2nd crval pair in the series
crval1 = trans2.right.lon
crval2 = trans2.right.lat
assert u.allclose(crval1, crval_table[2][0])
assert u.allclose(crval2, crval_table[2][1])
assert u.allclose(crval1, crval_table[2][0].to_value(u.deg))
assert u.allclose(crval2, crval_table[2][1].to_value(u.deg))

Check warning on line 167 in dkist/wcs/tests/test_models.py

View check run for this annotation

Codecov / codecov/patch

dkist/wcs/tests/test_models.py#L166-L167

Added lines #L166 - L167 were not covered by tests

pixel = (0*u.pix, 0*u.pix, 2*u.pix)
world = vct(*pixel)
Expand Down Expand Up @@ -274,7 +274,7 @@
((np.arange(10) * u.pix,
np.arange(10) * u.pix,
np.arange(5)[..., None] * u.pix,
np.arange(3)[..., None, None]), (3, 5, 10)),
np.arange(3)[..., None, None] * u.pix), (3, 5, 10)),
])
def test_varying_transform_4d_pc_shapes(pixel, lon_shape):
varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)] * u.pix
Expand Down
Loading