From cdbaa77a642f7a90abcdf25186f5157c89969846 Mon Sep 17 00:00:00 2001 From: "Brett M. Morris" Date: Thu, 22 Feb 2024 10:18:14 -0500 Subject: [PATCH] handle ndim>2 in maker utils --- src/roman_datamodels/maker_utils/_datamodels.py | 12 ++++++++---- tests/test_models.py | 7 +++++++ tests/test_open.py | 3 --- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/roman_datamodels/maker_utils/_datamodels.py b/src/roman_datamodels/maker_utils/_datamodels.py index 9491b139a..fc28b23be 100644 --- a/src/roman_datamodels/maker_utils/_datamodels.py +++ b/src/roman_datamodels/maker_utils/_datamodels.py @@ -450,8 +450,6 @@ def mk_source_catalog(*, filepath=None, **kwargs): source_catalog = stnode.SourceCatalog() source_catalog["source_catalog"] = kwargs.get("source_catalog", Table([range(3), range(3)], names=["a", "b"])) - source_catalog["meta"] = mk_common_meta() - source_catalog["meta"].update(kwargs.get("meta", dict(segmentation_map=''))) return save_node(source_catalog, filepath=filepath) @@ -470,11 +468,17 @@ def mk_segmentation_map(*, filepath=None, shape=(4096, 4096), **kwargs): ------- roman_datamodels.stnode.SegmentationMap """ - segmentation_map = stnode.SegmentationMap() + if len(shape) > 2: + shape = shape[1:3] + + warnings.warn( + f"{MESSAGE} assuming the first entry is n_groups followed by y, x. The remaining is thrown out!", UserWarning + ) + segmentation_map = stnode.SegmentationMap() segmentation_map["data"] = kwargs.get("data", np.zeros(shape, dtype=np.uint32)) segmentation_map["meta"] = mk_common_meta() - segmentation_map["meta"].update(kwargs.get("meta", dict(filename=''))) + segmentation_map["meta"].update(kwargs.get("meta", {})) return save_node(segmentation_map, filepath=filepath) diff --git a/tests/test_models.py b/tests/test_models.py index ff8545ccd..cade6555d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -715,6 +715,13 @@ def test_make_source_catalog(): assert isinstance(source_catalog_model.source_catalog, Table) +def test_make_segmentation_map(): + segmentation_map = utils.mk_segmentation_map() + segmentation_map_model = datamodels.SegmentationMapModel(segmentation_map) + + assert isinstance(segmentation_map_model.data, np.ndarray) + + def test_datamodel_info_search(capsys): wfi_science_raw = utils.mk_level1_science_raw(shape=(2, 8, 8)) af = asdf.AsdfFile() diff --git a/tests/test_open.py b/tests/test_open.py index 1e120ee2d..ecce57a27 100644 --- a/tests/test_open.py +++ b/tests/test_open.py @@ -226,9 +226,6 @@ def test_node_round_trip(tmp_path, node_class): @pytest.mark.filterwarnings("ignore:This function assumes shape is 2D") @pytest.mark.filterwarnings("ignore:Input shape must be 5D") def test_opening_model(tmp_path, node_class): - if node_class == stnode.SourceCatalog: - pytest.xfail("SourceCatalog does not have a meta attribute yet") - file_path = tmp_path / "test.asdf" # Create a node and write it to disk