Skip to content

Commit

Permalink
Address comments 2.
Browse files Browse the repository at this point in the history
  • Loading branch information
mairanteodoro committed Sep 17, 2024
1 parent 089bdc3 commit 64eb83a
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 52 deletions.
72 changes: 68 additions & 4 deletions romancal/tweakreg/tests/test_tweakreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from roman_datamodels import datamodels as rdm
from roman_datamodels import maker_utils
from stcal.tweakreg.astrometric_utils import get_catalog
from romancal.tweakreg.tweakreg_step import _validate_catalog_columns

from romancal.datamodels import ModelLibrary
from romancal.tweakreg import tweakreg_step as trs
Expand Down Expand Up @@ -638,15 +639,18 @@ def test_tweakreg_updates_group_id(tmp_path, base_image):
)
def test_tweakreg_save_valid_abs_refcat(tmp_path, abs_refcat, request):
"""Test that TweakReg saves the catalog used for absolute astrometry."""
os.chdir(tmp_path)

img = request.getfixturevalue("base_image")(shift_1=1000, shift_2=1000)
catalog_filename = "ref_catalog.ecsv"
abs_refcat_filename = f"fit_{abs_refcat.lower()}_ref.ecsv"
add_tweakreg_catalog_attribute(tmp_path, img, catalog_filename=catalog_filename)

trs.TweakRegStep.call(
[img], save_abs_catalog=True, abs_refcat=abs_refcat, catalog_path=str(tmp_path)
[img],
save_abs_catalog=True,
abs_refcat=abs_refcat,
catalog_path=str(tmp_path),
output_dir=str(tmp_path),
)

assert os.path.exists(tmp_path / abs_refcat_filename)
Expand All @@ -658,15 +662,18 @@ def test_tweakreg_save_valid_abs_refcat(tmp_path, abs_refcat, request):
)
def test_tweakreg_defaults_to_valid_abs_refcat(tmp_path, abs_refcat, request):
"""Test that TweakReg defaults to DEFAULT_ABS_REFCAT on invalid values."""
os.chdir(tmp_path)

img = request.getfixturevalue("base_image")(shift_1=1000, shift_2=1000)
catalog_filename = "ref_catalog.ecsv"
abs_refcat_filename = f"fit_{trs.DEFAULT_ABS_REFCAT.lower()}_ref.ecsv"
add_tweakreg_catalog_attribute(tmp_path, img, catalog_filename=catalog_filename)

trs.TweakRegStep.call(
[img], save_abs_catalog=True, abs_refcat=abs_refcat, catalog_path=str(tmp_path)
[img],
save_abs_catalog=True,
abs_refcat=abs_refcat,
catalog_path=str(tmp_path),
output_dir=str(tmp_path),
)

assert os.path.exists(tmp_path / abs_refcat_filename)
Expand Down Expand Up @@ -1004,3 +1011,60 @@ def test_tweakreg_skips_invalid_exposure_types(exposure_type, tmp_path, base_ima
assert hasattr(model.meta.cal_step, "tweakreg")
assert model.meta.cal_step.tweakreg == "SKIPPED"
res.shelve(model, i, modify=False)


@pytest.mark.parametrize(
"catalog_data, expected_colnames, raises_exception",
[
# both 'x' and 'y' columns present
({"x": [1, 2, 3], "y": [4, 5, 6]}, ["x", "y"], False),
# 'xcentroid' and 'ycentroid' columns present, should be renamed
({"xcentroid": [1, 2, 3], "ycentroid": [4, 5, 6]}, ["x", "y"], False),
# 'x' present, 'ycentroid' present, should rename 'ycentroid' to 'y'
({"x": [1, 2, 3], "ycentroid": [4, 5, 6]}, ["x", "y"], False),
# 'xcentroid' present, 'y' present, should rename 'xcentroid' to 'x'
({"xcentroid": [1, 2, 3], "y": [4, 5, 6]}, ["x", "y"], False),
# neither 'x' nor 'xcentroid' present
({"y": [4, 5, 6]}, None, True),
# neither 'y' nor 'ycentroid' present
({"x": [1, 2, 3]}, None, True),
# no relevant columns present
(
{"a": [1, 2, 3], "b": [4, 5, 6]},
None,
True,
),
],
)
def test_validate_catalog_columns(catalog_data, expected_colnames, raises_exception):
"""Test that TweakRegStep._validate_catalog_columns() correctly validates the
presence of required columns ('x' and 'y') in the provided catalog."""
catalog = table.Table(catalog_data)
if raises_exception:
with pytest.raises(ValueError):
_validate_catalog_columns(catalog)
else:
_validate_catalog_columns(catalog)
assert set(catalog.colnames) == set(expected_colnames)


def test_tweakreg_handles_mixed_exposure_types(tmp_path, base_image):
"""Test that TweakReg can handle mixed exposure types
(non-WFI_IMAGE data will be marked as SKIPPED only and won't be processed)."""
img1 = base_image(shift_1=1000, shift_2=1000)
add_tweakreg_catalog_attribute(tmp_path, img1, catalog_filename="img1")
img1.meta.exposure.type = "WFI_IMAGE"

img2 = base_image(shift_1=1000, shift_2=1000)
add_tweakreg_catalog_attribute(tmp_path, img2, catalog_filename="img2")
img2.meta.exposure.type = "WFI_IMAGE"

img3 = base_image(shift_1=1000, shift_2=1000)
img3.meta.exposure.type = "WFI_GRISM"

res = trs.TweakRegStep.call([img1, img2, img3])

assert len(res) == 3
assert img1.meta.cal_step.tweakreg == "COMPLETE"
assert img2.meta.cal_step.tweakreg == "COMPLETE"
assert img3.meta.cal_step.tweakreg == "SKIPPED"
115 changes: 67 additions & 48 deletions romancal/tweakreg/tweakreg_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,8 @@
from ..stpipe import RomanStep


def _oxford_or_str_join(str_list):
nelem = len(str_list)
if not nelem:
return "N/A"
str_list = list(map(repr, str_list))
if nelem == 1:
return str_list
elif nelem == 2:
return f"{str_list[0]} or {str_list[1]}"
else:
return ", ".join(map(repr, str_list[:-1])) + ", or " + repr(str_list[-1])


SINGLE_GROUP_REFCAT = ["GAIADR3", "GAIADR2", "GAIADR1"]
_SINGLE_GROUP_REFCAT_STR = _oxford_or_str_join(SINGLE_GROUP_REFCAT)
SINGLE_GROUP_REFCAT = tweakreg.SINGLE_GROUP_REFCAT
_SINGLE_GROUP_REFCAT_STR = tweakreg._SINGLE_GROUP_REFCAT_STR
DEFAULT_ABS_REFCAT = SINGLE_GROUP_REFCAT[0]

__all__ = ["TweakRegStep"]
Expand Down Expand Up @@ -99,11 +86,12 @@ def process(self, input):
if not images:
raise ValueError("Input must contain at least one image model.")

self.log.info("")
self.log.info(
f"Number of image groups to be aligned: {len(images.group_indices):d}."
)
self.log.info("Image groups:")
for name in images.group_names:
self.log.info(f" {name}")
# set the first image as reference
with images:
ref_image = images.borrow(0)
Expand Down Expand Up @@ -174,31 +162,21 @@ def process(self, input):

try:
catalog = self.get_tweakreg_catalog(
source_detection, image_model, i
source_detection, image_model
)
except AttributeError as e:
self.log.error(f"Failed to retrieve tweakreg_catalog: {e}")
images.shelve(image_model, i, modify=False)
raise AttributeError() from e
raise e

try:
for axis in ["x", "y"]:
# validate catalog columns
if axis not in catalog.colnames:
long_axis = f"{axis}centroid"
if long_axis in catalog.colnames:
catalog.rename_column(long_axis, axis)
else:
raise ValueError(
"'tweakreg' source catalogs must contain a header with "
"columns named either 'x' and 'y' or 'xcentroid' and 'ycentroid'."
)
# validate catalog columns
_validate_catalog_columns(catalog)
except ValueError as e:
self.log.error(f"Failed to validate catalog columns: {e}")
images.shelve(image_model, i, modify=False)
raise ValueError() from e
raise e

filename = image_model.meta.filename
catalog = tweakreg.filter_catalog_by_bounding_box(
catalog, image_model.meta.wcs.bounding_box
)
Expand All @@ -214,9 +192,9 @@ def process(self, input):
image_model.meta["tweakreg_catalog"] = catalog.as_array()
nsources = len(catalog)
self.log.info(
f"Detected {nsources} sources in {filename}."
f"Detected {nsources} sources in {image_model.meta.filename}."
if nsources
else f"No sources found in {filename}."
else f"No sources found in {image_model.meta.filename}."
)
# build image catalog
# catalog name
Expand All @@ -228,27 +206,33 @@ def process(self, input):
catalog_table.meta["name"] = catalog_name

imcats.append(
tweakreg.construct_wcs_corrector(
wcs=image_model.meta.wcs,
refang=image_model.meta.wcsinfo,
catalog=catalog_table,
group_id=image_model.meta.group_id,
)
{
"model_index": i,
"imcat": tweakreg.construct_wcs_corrector(
wcs=image_model.meta.wcs,
refang=image_model.meta.wcsinfo,
catalog=catalog_table,
group_id=image_model.meta.group_id,
),
}
)
images.shelve(image_model, i)

# run alignment only if it was possible to build image catalogs
if len(imcats):
if getattr(images, "group_indices", None) and len(images.group_indices) > 1:
self.do_relative_alignment(imcats)
# extract WCS correctors to use for image alignment
correctors = [x["imcat"] for x in imcats]
if len(images.group_indices) > 1:
self.do_relative_alignment(correctors)

if self.abs_refcat in SINGLE_GROUP_REFCAT:
self.do_absolute_alignment(ref_image, imcats)
self.do_absolute_alignment(ref_image, correctors)

# finalize step
with images:
for i, imcat in enumerate(imcats):
image_model = images.borrow(i)
for item in imcats:
imcat = item["imcat"]
image_model = images.borrow(item["model_index"])
image_model.meta.cal_step["tweakreg"] = "COMPLETE"
# remove source catalog
del image_model.meta["tweakreg_catalog"]
Expand Down Expand Up @@ -293,7 +277,7 @@ def process(self, input):
del image_model.meta["wcs_fit_results"][k]

image_model.meta.wcs = imcat.wcs
images.shelve(image_model, i)
images.shelve(image_model, item["model_index"])

return images

Expand Down Expand Up @@ -327,7 +311,7 @@ def read_catalog(self, catalog_name):
catalog = Table.read(catalog_name, format=self.catalog_format)
return catalog

def get_tweakreg_catalog(self, source_detection, image_model, index):
def get_tweakreg_catalog(self, source_detection, image_model):
"""
Retrieve the tweakreg catalog from source detection.
Expand All @@ -341,8 +325,6 @@ def get_tweakreg_catalog(self, source_detection, image_model, index):
The source detection metadata containing catalog information.
image_model : DataModel
The image model associated with the source detection.
index : int
The index of the image model in the collection.
Returns
-------
Expand Down Expand Up @@ -489,3 +471,40 @@ def _parse_catfile(catfile):
raise ValueError("'catfile' can contain at most two columns.")

return catdict


def _validate_catalog_columns(catalog):
"""
Validate the presence of required columns in the catalog.
This method checks if the specified axis column exists in the catalog.
If the axis is not found, it looks for a corresponding centroid column
and renames it if present. If neither is found, it raises an error.
Parameters
----------
catalog : Table
The catalog to validate, which should contain source information.
axis : str
The axis to check for in the catalog (e.g., 'x' or 'y').
Returns
-------
None
Raises
------
ValueError
If the required columns are missing from the catalog.
"""
for axis in ["x", "y"]:
if axis not in catalog.colnames:
long_axis = f"{axis}centroid"
if long_axis in catalog.colnames:
catalog.rename_column(long_axis, axis)
else:
raise ValueError(
"'tweakreg' source catalogs must contain a header with "
"columns named either 'x' and 'y' or 'xcentroid' and 'ycentroid'."
)
return catalog

0 comments on commit 64eb83a

Please sign in to comment.