Skip to content

Commit

Permalink
move skycell wcs reading into resample
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram committed Feb 25, 2025
1 parent d2ff8de commit ffad9ac
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 199 deletions.
134 changes: 134 additions & 0 deletions romancal/patch_match/patch_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@
import logging
import os
import os.path
import re

import asdf
import gwcs.wcs as wcs
import numpy as np
import spherical_geometry.polygon as sgp
import spherical_geometry.vector as sgv
from astropy import coordinates
from astropy import units as u
from astropy.modeling import models
from gwcs import WCS, coordinate_frames
from spherical_geometry.vector import normalize_vector
from stcal.alignment import util as wcs_util

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -252,3 +258,131 @@ def veccoords_to_tangent_plane(vertices, tangent_point_vec):
x_coords = np.dot(x_axis, avertices) * RAD_TO_ARCSEC
y_coords = np.dot(y_axis, avertices) * RAD_TO_ARCSEC
return x_coords, y_coords


def wcsinfo_to_wcs(wcsinfo, bounding_box=None, name="wcsinfo"):
"""Create a GWCS from the L3 wcsinfo meta
Parameters
----------
wcsinfo : dict or MosaicModel.meta.wcsinfo
The L3 wcsinfo to create a GWCS from.
bounding_box : None or 4-tuple
The bounding box in detector/pixel space. Form of input is:
(x_left, x_right, y_bottom, y_top)
name : str
Value of the `name` attribute of the GWCS object.
Returns
-------
wcs : wcs.GWCS
The GWCS object created.
"""
pixelshift = models.Shift(-wcsinfo["x_ref"], name="crpix1") & models.Shift(
-wcsinfo["y_ref"], name="crpix2"
)
pixelscale = models.Scale(wcsinfo["pixel_scale"], name="cdelt1") & models.Scale(
wcsinfo["pixel_scale"], name="cdelt2"
)
tangent_projection = models.Pix2Sky_TAN()
celestial_rotation = models.RotateNative2Celestial(
wcsinfo["ra_ref"], wcsinfo["dec_ref"], 180.0
)

matrix = wcsinfo.get("rotation_matrix", None)
if matrix:
matrix = np.array(matrix)
else:
orientat = wcsinfo.get("orientat", 0.0)
matrix = wcs_util.calc_rotation_matrix(
np.deg2rad(orientat), v3i_yangle=0.0, vparity=1
)
matrix = np.reshape(matrix, (2, 2))
rotation = models.AffineTransformation2D(matrix, name="pc_rotation_matrix")
det2sky = (
pixelshift | rotation | pixelscale | tangent_projection | celestial_rotation
)

detector_frame = coordinate_frames.Frame2D(
name="detector", axes_names=("x", "y"), unit=(u.pix, u.pix)
)
sky_frame = coordinate_frames.CelestialFrame(
reference_frame=coordinates.ICRS(), name="icrs", unit=(u.deg, u.deg)
)
wcsobj = WCS([(detector_frame, det2sky), (sky_frame, None)], name=name)

if bounding_box:
wcsobj.bounding_box = bounding_box

return wcsobj


def skycell_to_wcs(skycell_record):
"""From a skycell record, generate a GWCS
Parameters
----------
skycell_record : dict
A skycell record, or row, from the skycell patches table.
Returns
-------
wcsobj : wcs.GWCS
The GWCS object from the skycell record.
"""
wcsinfo = dict()

# The scale is given in arcseconds per pixel. Convert to degrees.
wcsinfo["pixel_scale"] = float(skycell_record["pixel_scale"]) / 3600.0

# Remaining components of the wcsinfo block
wcsinfo["ra_ref"] = float(skycell_record["ra_projection_center"])
wcsinfo["dec_ref"] = float(skycell_record["dec_projection_center"])
wcsinfo["x_ref"] = float(skycell_record["x0_projection"])
wcsinfo["y_ref"] = float(skycell_record["y0_projection"])
wcsinfo["orientat"] = float(skycell_record["orientat_projection_center"])
wcsinfo["rotation_matrix"] = None

# Bounding box of the skycell. Note that the center of the pixels are at (0.5, 0.5)
bounding_box = (
(-0.5, -0.5 + skycell_record["nx"]),
(-0.5, -0.5 + skycell_record["ny"]),
)

wcsobj = wcsinfo_to_wcs(wcsinfo, bounding_box=bounding_box)

wcsobj.array_shape = tuple(
int(axs[1] - axs[0] + 0.5)
for axs in wcsobj.bounding_box.bounding_box(order="C")
)
return wcsobj


def to_skycell_wcs(input):
# check to see if the product name contains a skycell name & if true get the skycell record
if "target" not in input.asn:
return None
skycell_name = input.asn["target"]

if not re.match(r"r\d{3}\w{2}\d{2}x\d{2}y\d{2}", skycell_name):
return None

if "skycell_wcs_info" in input.asn and input.asn["skycell_wcs_info"] != "none":
skycell_record = input.asn["skycell_wcs_info"]
else:
if PATCH_TABLE is None:
load_patch_table()
skycell_record = PATCH_TABLE[
np.where(PATCH_TABLE["name"][:] == skycell_name)[0][0]
]
log.info("Skycell record %s:", skycell_record)

# extract the wcs info from the record for skycell_to_wcs
log.info(
"Creating skycell image at ra: %f dec %f",
float(skycell_record["ra_center"]),
float(skycell_record["dec_center"]),
)
return skycell_to_wcs(skycell_record)
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Unit tests for the mosaic pipeline"""
"""Unit tests for skycell wcs functions"""

import numpy as np

import romancal.pipeline.mosaic_pipeline as mp
from romancal.patch_match.patch_match import skycell_to_wcs, wcsinfo_to_wcs


def test_skycell_to_wcs():
Expand Down Expand Up @@ -61,7 +61,7 @@ def test_skycell_to_wcs():
],
)

wcs = mp.skycell_to_wcs(skycell)
wcs = skycell_to_wcs(skycell)

assert np.allclose(
wcs(
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_wcsinfo_to_wcs():
"orientat": 359.8466793994546,
}

wcs = mp.wcsinfo_to_wcs(wcsinfo)
wcs = wcsinfo_to_wcs(wcsinfo)

assert np.allclose(
wcs(wcsinfo["x_ref"], wcsinfo["y_ref"]), (wcsinfo["ra_ref"], wcsinfo["dec_ref"])
Expand Down
Loading

0 comments on commit ffad9ac

Please sign in to comment.