From 1903ac7a31e6cd51353b0fe8882834cdc777faf9 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Thu, 13 Feb 2025 12:57:24 -0500 Subject: [PATCH 01/14] Add unit test demonstrating the bug --- gwcs/tests/test_wcs.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index 37918bbb..e68fce7d 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -1813,3 +1813,25 @@ def test_direct_numerical_inverse(gwcs_romanisim): out = gwcs_romanisim.numerical_inverse(*ra_dec) assert_allclose(xy, out) + + +def test_array_high_level_output(): + """ + Test that we don't loose array values when requesting a high-level output + from a WCS object. + """ + input_frame = cf.CoordinateFrame( + naxes=1, + axes_type=("SPATIAL",), + axes_order=(0,), + name="pixels", + unit=(u.pix,), + axes_names=("x",), + ) + output_frame = cf.SpectralFrame(unit=(u.nm,), axes_names=("lambda",)) + wave_model = models.Scale(0.1) | models.Shift(500) + gwcs = wcs.WCS([(input_frame, wave_model), (output_frame, None)]) + assert ( + gwcs(np.array([0, 1, 2]), with_units=True) + == coord.SpectralCoord([500, 500.1, 500.2] * u.nm) + ).all() From 331c708c03e9eed43771a127b0f7690123d8f4f7 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Thu, 13 Feb 2025 13:20:19 -0500 Subject: [PATCH 02/14] Bugfix for the loss of array values This was largely an issue of the zip(strict=False) causing missed entries. I made this strict=True --- gwcs/coordinate_frames.py | 23 +++++++++++++++++++++++ gwcs/wcs/_wcs.py | 20 ++++++-------------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py index 62519b58..142d4ba8 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames.py @@ -388,6 +388,29 @@ def _native_world_axis_object_components(self): to be able to get the components in their native order. """ + def add_units(self, arrays: u.Quantity | np.ndarray | float) -> tuple[u.Quantity]: + """ + Add units to the arrays + """ + return tuple( + u.Quantity(array, unit=unit) + for array, unit in zip(arrays, self.unit, strict=True) + ) + + def remove_units( + self, arrays: u.Quantity | np.ndarray | float + ) -> tuple[np.ndarray]: + """ + Remove units from the input arrays + """ + if self.naxes == 1: + arrays = (arrays,) + + return tuple( + array.to_value(unit) if isinstance(array, u.Quantity) else array + for array, unit in zip(arrays, self.unit, strict=True) + ) + class CoordinateFrame(BaseCoordinateFrame): """ diff --git a/gwcs/wcs/_wcs.py b/gwcs/wcs/_wcs.py index 51a5790f..24d6aa79 100644 --- a/gwcs/wcs/_wcs.py +++ b/gwcs/wcs/_wcs.py @@ -6,7 +6,6 @@ import astropy.units as u import numpy as np -from astropy import utils as astutil from astropy.io import fits from astropy.modeling import fix_inputs, projections from astropy.modeling.bounding_box import ModelBoundingBox as Bbox @@ -117,13 +116,10 @@ def __init__( self._pixel_shape = None def _add_units_input( - self, arrays: list[np.ndarray], frame: CoordinateFrame | None + self, arrays: np.ndarray | float, frame: CoordinateFrame | None ) -> tuple[u.Quantity, ...]: if frame is not None: - return tuple( - u.Quantity(array, unit) - for array, unit in zip(arrays, frame.unit, strict=False) - ) + return frame.add_units(arrays) return arrays @@ -131,10 +127,7 @@ def _remove_units_input( self, arrays: list[u.Quantity], frame: CoordinateFrame | None ) -> tuple[np.ndarray, ...]: if frame is not None: - return tuple( - array.to_value(unit) if isinstance(array, u.Quantity) else array - for array, unit in zip(arrays, frame.unit, strict=False) - ) + return frame.remove_units(arrays) return arrays @@ -166,10 +159,7 @@ def __call__( results = self._call_forward( *args, with_bounding_box=with_bounding_box, fill_value=fill_value, **kwargs ) - if with_units: - if not astutil.isiterable(results): - results = (results,) # values are always expected to be arrays or scalars not quantities results = self._remove_units_input(results, self.output_frame) high_level = values_to_high_level_objects(*results, low_level_wcs=self) @@ -403,7 +393,9 @@ def outside_footprint(self, world_arrays): max_ax = axis_range[~m].min() outside = (coord > min_ax) & (coord < max_ax) else: - coord_ = self._remove_units_input([coord], self.output_frame)[0] + coord_ = self._remove_quantity_output( + world_arrays, self.output_frame + )[idim] outside = (coord_ < min_ax) | (coord_ > max_ax) if np.any(outside): if np.isscalar(coord): From b1df72300fbab603575f55e65fc31aa5d240d099 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Thu, 13 Feb 2025 13:32:27 -0500 Subject: [PATCH 03/14] Update changes --- CHANGES.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 477d87f9..081ce985 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,9 @@ - Fix API issue with ``wcs.numerical_inverse``. [#565] +- Fix bug where "vector" (shape (n,) not shape (1, n)) arrays would loose all their entries except the + first if ``with_units=True`` was used. [#563] + 0.24.0 (2025-02-04) ------------------- From 9feb67cc8723ece5618fe0f74f9be66a7dddd172 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Fri, 14 Feb 2025 09:53:01 -0500 Subject: [PATCH 04/14] Order clockwise does not always apply to a bounding box. Move units strip to directly on the bounding box --- gwcs/wcs/_wcs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gwcs/wcs/_wcs.py b/gwcs/wcs/_wcs.py index 24d6aa79..1f919d3a 100644 --- a/gwcs/wcs/_wcs.py +++ b/gwcs/wcs/_wcs.py @@ -1194,7 +1194,6 @@ def footprint(self, bounding_box=None, center=False, axis_type="all"): """ def _order_clockwise(v): - v = [self._remove_units_input(vv, self.input_frame) for vv in v] return np.asarray( [ [v[0][0], v[1][0]], @@ -1218,7 +1217,9 @@ def _order_clockwise(v): bb = np.asarray([b.value for b in bb]) * bb[0].unit vertices = (bb,) elif all_spatial: - vertices = _order_clockwise(bb) + vertices = _order_clockwise( + [self._remove_units_input(b, self.input_frame) for b in bb] + ) else: vertices = np.array(list(itertools.product(*bb))).T From da878c4b7157b52d611903df6b49d85dad983af8 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Wed, 12 Feb 2025 14:34:53 -0500 Subject: [PATCH 05/14] Add test for bug --- gwcs/tests/test_wcs.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index e68fce7d..53ba5b49 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -1835,3 +1835,36 @@ def test_array_high_level_output(): gwcs(np.array([0, 1, 2]), with_units=True) == coord.SpectralCoord([500, 500.1, 500.2] * u.nm) ).all() + + +def test_parameterless_transform(): + """ + Test that a transform with no parameters correctly handles units. + -> The wcs does not introduce units when evaluating the forward or backward + transform for models with no parameters + Regression test for #558 + """ + + in_frame = cf.Frame2D(name="in_frame") + out_frame = cf.Frame2D(name="out_frame") + + gwcs = wcs.WCS( + [ + (in_frame, models.Identity(2)), + (out_frame, None), + ] + ) + + # The expectation for this wcs is that: + # - gwcs(1, 1) has no units + # (__call__ apparently is supposed to pass units through?) + # - gwcs(1*u.pix, 1*u.pix) has units + # - gwcs.invert(1, 1) has no units + # - gwcs.invert(1*u.pix, 1*u.pix) has no units + + # No units introduced by the forward transform + assert gwcs(1, 1) == (1, 1) + assert gwcs(1 * u.pix, 1 * u.pix) == (1 * u.pix, 1 * u.pix) + + assert gwcs.invert(1, 1) == (1, 1) + assert gwcs.invert(1 * u.pix, 1 * u.pix) == (1, 1) From 84b479eebb9992dd77e18e82e08407f8548d98d2 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Wed, 12 Feb 2025 14:37:14 -0500 Subject: [PATCH 06/14] Bugfix for the reported bug This fixes the bug so that gwcs works as it currently states. Its also not perfect as the issue lies up in astropy.modeling not gwcs itself. --- gwcs/wcs/_wcs.py | 108 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 91 insertions(+), 17 deletions(-) diff --git a/gwcs/wcs/_wcs.py b/gwcs/wcs/_wcs.py index 1f919d3a..bf3fc067 100644 --- a/gwcs/wcs/_wcs.py +++ b/gwcs/wcs/_wcs.py @@ -168,6 +168,85 @@ def __call__( return high_level return results + def _evaluate_transform( + self, + transform, + from_frame, + to_frame, + *args, + with_bounding_box: bool = True, + fill_value: float | np.number = np.nan, + **kwargs, + ): + """ + Introduces or removes units from the arguments as need so that the transform + can be successfully evaluated. + + Notes + ----- + Much of the logic in this method is due to the unfortunate fact that the + `uses_quantity` property for models is not reliable for determining if one + must pass quantities or not. It instead tells you: + 1. If it has any parameter that is a quantity + 2. It defaults to true for parameterless models. + + This is problematic because its entirely possible to construct a model with + a parameter that is a quantity but the model itself either doesn't require + them or in fact cannot use them. This is a very rare case but it could happen. + Currently, this case is not handled, but it is worth noting in case it comes up + + The more problematic case is for parameterless models. `uses_quantity` assumes + that if there are no parameters, then the model is agnostic to quantity inputs. + This is an incorrect assumption, even with in `astropy.modeling`'s built in + models. The `Tabular1D` model for example has no "parameters" but it can + require quantities if its "points" construction input is a quantity. This + is the main case for the try/except block in this method. + + Properly dealing with this will require upstream work in `astropy.modeling` + which is outside the scope of what GWCS can control. + + to_frame is included as we really ought to be stripping the result of units + but we currently are not. API refactor should include addressing this. + """ + + # Validate that the input type matches what the transform expects + input_is_quantity = any(isinstance(a, u.Quantity) for a in args) + + def _transform(*args): + """Wrap the transform evaluation""" + + return transform( + *args, + with_bounding_box=with_bounding_box, + fill_value=fill_value, + **kwargs, + ) + + # Models with no parameters claim they use quantities but this may incorrectly + # introduce units so we don't at first + if ( + not input_is_quantity + and transform.uses_quantity + and transform.parameters.size + ): + args = self._add_units_input(args, from_frame) + if not transform.uses_quantity and input_is_quantity: + args = self._remove_units_input(args, from_frame) + + try: + return _transform(*args) + except u.UnitsError: + # In this case we are handling parameterless models that require units + # to function correctly. + if ( + not input_is_quantity + and transform.uses_quantity + and not transform.parameters.size + ): + return _transform(*self._add_units_input(args, from_frame)) + + raise + def _call_forward( self, *args, @@ -193,15 +272,14 @@ def _call_forward( msg = "WCS.forward_transform is not implemented." raise NotImplementedError(msg) - # Validate that the input type matches what the transform expects - input_is_quantity = any(isinstance(a, u.Quantity) for a in args) - if not input_is_quantity and transform.uses_quantity: - args = self._add_units_input(args, from_frame) - if not transform.uses_quantity and input_is_quantity: - args = self._remove_units_input(args, from_frame) - - return transform( - *args, with_bounding_box=with_bounding_box, fill_value=fill_value, **kwargs + return self._evaluate_transform( + transform, + from_frame, + to_frame, + *args, + with_bounding_box=with_bounding_box, + fill_value=fill_value, + **kwargs, ) def in_image(self, *args, **kwargs): @@ -324,16 +402,12 @@ def _call_backward( args = self.outside_footprint(args) if transform is not None: - # Validate that the input type matches what the transform expects - input_is_quantity = any(isinstance(a, u.Quantity) for a in args) - if not input_is_quantity and transform.uses_quantity: - args = self._add_units_input(args, self.output_frame) - if not transform.uses_quantity and input_is_quantity: - args = self._remove_units_input(args, self.output_frame) - # remove iterative inverse-specific keyword arguments: akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS} - result = transform( + result = self._evaluate_transform( + transform, + self.output_frame, + self.input_frame, *args, with_bounding_box=with_bounding_box, fill_value=fill_value, From d40560200c9a17290b635fd159f45767a200cc3e Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Thu, 13 Feb 2025 09:52:43 -0500 Subject: [PATCH 07/14] Update changes --- CHANGES.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 081ce985..e31aa0ae 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,9 @@ - Fix API issue with ``wcs.numerical_inverse``. [#565] +- Bugfix for ``__call__`` and ``invert`` incorrectly handling units when involving + "parameterless" transforms. [#562] + - Fix bug where "vector" (shape (n,) not shape (1, n)) arrays would loose all their entries except the first if ``with_units=True`` was used. [#563] From c23afd363f2f8e89df25c57ce38bb3cab19fe33c Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Wed, 8 Jan 2025 16:40:03 -0500 Subject: [PATCH 08/14] Add typing module to collect common type hints This also adds mypy checking in a way to support gradual adoption of type hints. --- .pre-commit-config.yaml | 10 +++++++ gwcs/_typing.py | 57 +++++++++++++++++++++++++++++++++++++++ pyproject.toml | 60 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 gwcs/_typing.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a0f3981f..e3598fda 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -59,3 +59,13 @@ repos: - id: ruff args: ["--fix", "--show-fixes"] - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.15.0" + hooks: + - id: mypy + files: gwcs + args: [] + additional_dependencies: # add dependencies for mypy to pull information from + - numpy>=2 + - astropy>=7 diff --git a/gwcs/_typing.py b/gwcs/_typing.py new file mode 100644 index 00000000..2300d163 --- /dev/null +++ b/gwcs/_typing.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from fractions import Fraction +from typing import TypeAlias + +import numpy as np +import numpy.typing as npt +from astropy.coordinates import SkyCoord, SpectralCoord, StokesCoord +from astropy.modeling.bounding_box import CompoundBoundingBox, ModelBoundingBox +from astropy.time import Time +from astropy.units import Quantity + +__all__ = [ + "BoundingBox", + "Bounds", + "HighLevelObject", + "Interval", + "LowLevelArrays", + "LowLevelUnitArrays", + "LowLevelUnitValue", + "LowLevelValue", + "OutputLowLevelArray", + "Real", + "WorldAxisClass", + "WorldAxisComponent", + "WorldAxisComponents", +] + +Real: TypeAlias = int | float | Fraction | np.integer | np.floating + +Interval: TypeAlias = tuple[Real, Real] +Bounds: TypeAlias = tuple[Interval, ...] | None + +BoundingBox: TypeAlias = ModelBoundingBox | CompoundBoundingBox | None + +# This is to represent a single value from a low-level function. +LowLevelValue: TypeAlias = Real | npt.NDArray[np.number] +# Handle when units are a possibility. Not all functions allow units in/out +LowLevelUnitValue: TypeAlias = LowLevelValue | Quantity + +# This is to represent all the values together for a single low-level function. +LowLevelArrays: TypeAlias = tuple[LowLevelValue, ...] +LowLevelUnitArrays: TypeAlias = tuple[LowLevelUnitValue, ...] + +# This is to represent a general array output from a low-level function. +# Due to the fact 1D outputs are returned as a single value, rather than a tuple. +OutputLowLevelArray: TypeAlias = LowLevelValue | LowLevelArrays + +HighLevelObject: TypeAlias = Time | SkyCoord | SpectralCoord | StokesCoord | Quantity + +WorldAxisComponent: TypeAlias = tuple[str, str | int, str] +WorldAxisClass: TypeAlias = tuple[ + type | str, tuple[int | None, ...], dict[str, HighLevelObject] +] + +WorldAxisComponents: TypeAlias = list[WorldAxisComponent] +WorldAxisClasses: TypeAlias = dict[str, WorldAxisClass] diff --git a/pyproject.toml b/pyproject.toml index 961b0bed..859b7164 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "gwcs" description = "Generalized World Coordinate System" -requires-python = ">=3.10" +requires-python = ">=3.11" authors = [ { name = "gwcs developers", email = "help@stsci.edu" }, ] @@ -44,7 +44,6 @@ docs = [ "sphinx-copybutton", "pydata-sphinx-theme", "sphinx-asdf", - "tomli; python_version <'3.11'", ] test = [ "ci-watson>=0.3.0", @@ -175,3 +174,60 @@ ignore = [ "gwcs/converters/tests/*" = ["S101"] "docs/conf.py" = ["INP001", "ERA001"] "convert_schemas.py" = ["PTH"] + + +[tool.mypy] +files = "gwcs" +python_version = "3.11" +strict = true +enable_error_code = [ + "ignore-without-code", + "redundant-expr", + "truthy-bool", +] +exclude = [ + 'gwcs/tests/$', + 'gwcs/converters/tests/$', + 'gwcs/wcs/$', + 'gwcs/api.py', + 'gwcs/coordinate_frames.py', + 'gwcs/examples.py', + 'gwcs/extension.py', + 'gwcs/geometry.py', + 'gwcs/region.py', + 'gwcs/selector.py', + 'gwcs/spectroscopy.py', + 'gwcs/utils.py', + 'gwcs/wcstools.py', +] +warn_unreachable = true + +[[tool.mypy.overrides]] +module = [ + "astropy.coordinates.*", + "astropy.modeling.*", + # "astropy.table.*", + "astropy.time.*", + "astropy.units.*", + # "astropy.utils.*", +] +follow_untyped_imports = true + +# For gradual adoption +[[tool.mypy.overrides]] +module = [ + "gwcs.tests.*", + "gwcs.converters.*", + "gwcs.wcs.*", + "gwcs.api.*", + "gwcs.coordinate_frames.*", + "gwcs.examples.*", + "gwcs.extension.*", + "gwcs.geometry.*", + "gwcs.region.*", + "gwcs.selector.*", + "gwcs.spectroscopy.*", + "gwcs.utils.*", + "gwcs.wcstools.*", +] +ignore_errors = true From 7be9351b695f2dd11068e20ace5b5846aad0bee8 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Wed, 8 Jan 2025 13:54:00 -0500 Subject: [PATCH 09/14] Move coordinate_frames into its own module --- gwcs/coordinate_frames/__init__.py | 25 +++++++++++++++++++ .../_coordinate_frames.py} | 1 + gwcs/tests/test_coordinate_systems.py | 14 +++++++---- 3 files changed, 35 insertions(+), 5 deletions(-) create mode 100644 gwcs/coordinate_frames/__init__.py rename gwcs/{coordinate_frames.py => coordinate_frames/_coordinate_frames.py} (99%) diff --git a/gwcs/coordinate_frames/__init__.py b/gwcs/coordinate_frames/__init__.py new file mode 100644 index 00000000..96f666f6 --- /dev/null +++ b/gwcs/coordinate_frames/__init__.py @@ -0,0 +1,25 @@ +from ._coordinate_frames import ( + BaseCoordinateFrame, + CelestialFrame, + CompositeFrame, + CoordinateFrame, + EmptyFrame, + Frame2D, + SpectralFrame, + StokesFrame, + TemporalFrame, + get_ctype_from_ucd, +) + +__all__ = [ + "BaseCoordinateFrame", + "CelestialFrame", + "CompositeFrame", + "CoordinateFrame", + "EmptyFrame", + "Frame2D", + "SpectralFrame", + "StokesFrame", + "TemporalFrame", + "get_ctype_from_ucd", +] diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames/_coordinate_frames.py similarity index 99% rename from gwcs/coordinate_frames.py rename to gwcs/coordinate_frames/_coordinate_frames.py index 142d4ba8..295f6461 100644 --- a/gwcs/coordinate_frames.py +++ b/gwcs/coordinate_frames/_coordinate_frames.py @@ -148,6 +148,7 @@ "SpectralFrame", "StokesFrame", "TemporalFrame", + "get_ctype_from_ucd", ] diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index fd6cf834..ed8e0ee5 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -16,6 +16,10 @@ from gwcs import WCS from gwcs import coordinate_frames as cf +from gwcs.coordinate_frames._coordinate_frames import ( + _ALLOWED_UCD_DUPLICATES, + _ucd1_to_ctype_name_mapping, +) astropy_version = astropy.__version__ @@ -520,8 +524,8 @@ def test_ucd1_to_ctype_not_out_of_sync(caplog): dictionary with new types defined in ``astropy``'s ``CTYPE_TO_UCD1``. """ - cf._ucd1_to_ctype_name_mapping( - ctype_to_ucd=CTYPE_TO_UCD1, allowed_ucd_duplicates=cf._ALLOWED_UCD_DUPLICATES + _ucd1_to_ctype_name_mapping( + ctype_to_ucd=CTYPE_TO_UCD1, allowed_ucd_duplicates=_ALLOWED_UCD_DUPLICATES ) assert len(caplog.record_tuples) == 0 @@ -536,8 +540,8 @@ def test_ucd1_to_ctype(caplog): ctype_to_ucd = dict(**CTYPE_TO_UCD1, **new_ctype_to_ucd) - inv_map = cf._ucd1_to_ctype_name_mapping( - ctype_to_ucd=ctype_to_ucd, allowed_ucd_duplicates=cf._ALLOWED_UCD_DUPLICATES + inv_map = _ucd1_to_ctype_name_mapping( + ctype_to_ucd=ctype_to_ucd, allowed_ucd_duplicates=_ALLOWED_UCD_DUPLICATES ) assert caplog.record_tuples[-1][1] == logging.WARNING @@ -545,7 +549,7 @@ def test_ucd1_to_ctype(caplog): "Found unsupported duplicate physical type" ) - for k, v in cf._ALLOWED_UCD_DUPLICATES.items(): + for k, v in _ALLOWED_UCD_DUPLICATES.items(): assert inv_map.get(k, "") == v for k, v in inv_map.items(): From 6e3d941930fcda20bcc315c5f609933428f415ef Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Wed, 8 Jan 2025 13:57:27 -0500 Subject: [PATCH 10/14] Factor FrameProperties into its own module --- gwcs/coordinate_frames/__init__.py | 118 ++++++++++++ gwcs/coordinate_frames/_coordinate_frames.py | 190 +------------------ gwcs/coordinate_frames/_properties.py | 74 ++++++++ 3 files changed, 194 insertions(+), 188 deletions(-) create mode 100644 gwcs/coordinate_frames/_properties.py diff --git a/gwcs/coordinate_frames/__init__.py b/gwcs/coordinate_frames/__init__.py index 96f666f6..a6cde3a6 100644 --- a/gwcs/coordinate_frames/__init__.py +++ b/gwcs/coordinate_frames/__init__.py @@ -1,3 +1,121 @@ +""" +This module defines coordinate frames for describing the inputs and/or outputs +of a transform. + +In the block diagram, the WCS pipeline has a two stage transformation (two +astropy Model instances), with an input frame, an output frame and an +intermediate frame. + +.. code-block:: + + ┌───────────────┐ + │ │ + │ Input │ + │ Frame │ + │ │ + └───────┬───────┘ + │ + ┌─────▼─────┐ + │ Transform │ + └─────┬─────┘ + │ + ┌───────▼───────┐ + │ │ + │ Intermediate │ + │ Frame │ + │ │ + └───────┬───────┘ + │ + ┌─────▼─────┐ + │ Transform │ + └─────┬─────┘ + │ + ┌───────▼───────┐ + │ │ + │ Output │ + │ Frame │ + │ │ + └───────────────┘ + + +Each frame instance is both metadata for the inputs/outputs of a transform and +also a converter between those inputs/outputs and richer coordinate +representations of those inputs/outputs. + +For example, an output frame of type `~gwcs.coordinate_frames.SpectralFrame` +provides metadata to the `.WCS` object such as the ``axes_type`` being +``"SPECTRAL"`` and the unit of the output etc. The output frame also provides a +converter of the numeric output of the transform to a +`~astropy.coordinates.SpectralCoord` object, by combining this metadata with the +numerical values. + +``axes_order`` and conversion between objects and arguments +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +One of the key concepts regarding coordinate frames is the ``axes_order`` argument. +This argument is used to map from the components of the frame to the inputs/outputs +of the transform. To illustrate this consider this situation where you have a +forward transform which outputs three coordinates ``[lat, lambda, lon]``. These +would be represented as a `.SpectralFrame` and a `.CelestialFrame`, however, the +axes of a `.CelestialFrame` are always ``[lon, lat]``, so by specifying two +frames as + +.. code-block:: python + + [SpectralFrame(axes_order=(1,)), CelestialFrame(axes_order=(2, 0))] + +we would map the outputs of this transform into the correct positions in the frames. + As shown below, this is also used when constructing the inputs to the inverse + transform. + + +When taking the output from the forward transform the following transformation +is performed by the coordinate frames: + +.. code-block:: + + lat, lambda, lon + │ │ │ + └──────┼─────┼────────┐ + ┌───────────┘ └──┐ │ + │ │ │ + ┌─────────▼────────┐ ┌──────▼─────▼─────┐ + │ │ │ │ + │ SpectralFrame │ │ CelestialFrame │ + │ │ │ │ + │ (1,) │ │ (2, 0) │ + │ │ │ │ + └─────────┬────────┘ └──────────┬────┬──┘ + │ │ │ + │ │ │ + ▼ ▼ ▼ + SpectralCoord(lambda) SkyCoord((lon, lat)) + + +When considering the backward transform the following transformations take place +in the coordinate frames before the transform is called: + +.. code-block:: + + SpectralCoord(lambda) SkyCoord((lon, lat)) + │ │ │ + └─────┐ ┌────────────┘ │ + │ │ ┌────────────┘ + ▼ ▼ ▼ + [lambda, lon, lat] + │ │ │ + │ │ │ + ┌──────▼─────▼────▼────┐ + │ │ + │ Sort by axes_order │ + │ │ + └────┬──────┬─────┬────┘ + │ │ │ + ▼ ▼ ▼ + lat, lambda, lon + +""" + from ._coordinate_frames import ( BaseCoordinateFrame, CelestialFrame, diff --git a/gwcs/coordinate_frames/_coordinate_frames.py b/gwcs/coordinate_frames/_coordinate_frames.py index 295f6461..45aa973f 100644 --- a/gwcs/coordinate_frames/_coordinate_frames.py +++ b/gwcs/coordinate_frames/_coordinate_frames.py @@ -1,134 +1,15 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst -""" -This module defines coordinate frames for describing the inputs and/or outputs -of a transform. - -In the block diagram, the WCS pipeline has a two stage transformation (two -astropy Model instances), with an input frame, an output frame and an -intermediate frame. - -.. code-block:: - - ┌───────────────┐ - │ │ - │ Input │ - │ Frame │ - │ │ - └───────┬───────┘ - │ - ┌─────▼─────┐ - │ Transform │ - └─────┬─────┘ - │ - ┌───────▼───────┐ - │ │ - │ Intermediate │ - │ Frame │ - │ │ - └───────┬───────┘ - │ - ┌─────▼─────┐ - │ Transform │ - └─────┬─────┘ - │ - ┌───────▼───────┐ - │ │ - │ Output │ - │ Frame │ - │ │ - └───────────────┘ - - -Each frame instance is both metadata for the inputs/outputs of a transform and -also a converter between those inputs/outputs and richer coordinate -representations of those inputs/outputs. - -For example, an output frame of type `~gwcs.coordinate_frames.SpectralFrame` -provides metadata to the `.WCS` object such as the ``axes_type`` being -``"SPECTRAL"`` and the unit of the output etc. The output frame also provides a -converter of the numeric output of the transform to a -`~astropy.coordinates.SpectralCoord` object, by combining this metadata with the -numerical values. - -``axes_order`` and conversion between objects and arguments -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -One of the key concepts regarding coordinate frames is the ``axes_order`` argument. -This argument is used to map from the components of the frame to the inputs/outputs -of the transform. To illustrate this consider this situation where you have a -forward transform which outputs three coordinates ``[lat, lambda, lon]``. These -would be represented as a `.SpectralFrame` and a `.CelestialFrame`, however, the -axes of a `.CelestialFrame` are always ``[lon, lat]``, so by specifying two -frames as - -.. code-block:: python - - [SpectralFrame(axes_order=(1,)), CelestialFrame(axes_order=(2, 0))] - -we would map the outputs of this transform into the correct positions in the frames. - As shown below, this is also used when constructing the inputs to the inverse - transform. - - -When taking the output from the forward transform the following transformation -is performed by the coordinate frames: - -.. code-block:: - - lat, lambda, lon - │ │ │ - └──────┼─────┼────────┐ - ┌───────────┘ └──┐ │ - │ │ │ - ┌─────────▼────────┐ ┌──────▼─────▼─────┐ - │ │ │ │ - │ SpectralFrame │ │ CelestialFrame │ - │ │ │ │ - │ (1,) │ │ (2, 0) │ - │ │ │ │ - └─────────┬────────┘ └──────────┬────┬──┘ - │ │ │ - │ │ │ - ▼ ▼ ▼ - SpectralCoord(lambda) SkyCoord((lon, lat)) - - -When considering the backward transform the following transformations take place -in the coordinate frames before the transform is called: - -.. code-block:: - - SpectralCoord(lambda) SkyCoord((lon, lat)) - │ │ │ - └─────┐ ┌────────────┘ │ - │ │ ┌────────────┘ - ▼ ▼ ▼ - [lambda, lon, lat] - │ │ │ - │ │ │ - ┌──────▼─────▼────▼────┐ - │ │ - │ Sort by axes_order │ - │ │ - └────┬──────┬─────┬────┘ - │ │ │ - ▼ ▼ ▼ - lat, lambda, lon - -""" import abc import contextlib import logging import numbers from collections import defaultdict -from dataclasses import InitVar, dataclass import numpy as np from astropy import coordinates as coord from astropy import time from astropy import units as u -from astropy import utils as astutil from astropy.coordinates import StokesCoord from astropy.utils.misc import isiterable from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 @@ -136,7 +17,8 @@ high_level_objects_to_values, values_to_high_level_objects, ) -from astropy.wcs.wcsapi.low_level_api import VALID_UCDS, validate_physical_types + +from ._properties import FrameProperties __all__ = [ "BaseCoordinateFrame", @@ -205,74 +87,6 @@ def get_ctype_from_ucd(ucd): return UCD1_TO_CTYPE.get(ucd, "") -@dataclass -class FrameProperties: - naxes: InitVar[int] - axes_type: tuple[str] - unit: tuple[u.Unit] = None - axes_names: tuple[str] = None - axis_physical_types: list[str] = None - - def __post_init__(self, naxes): - if isinstance(self.axes_type, str): - self.axes_type = (self.axes_type,) - else: - self.axes_type = tuple(self.axes_type) - - if len(self.axes_type) != naxes: - msg = "Length of axes_type does not match number of axes." - raise ValueError(msg) - - if self.unit is not None: - unit = tuple(self.unit) if astutil.isiterable(self.unit) else (self.unit,) - if len(unit) != naxes: - msg = "Number of units does not match number of axes." - raise ValueError(msg) - self.unit = tuple(u.Unit(au) for au in unit) - else: - self.unit = tuple(u.dimensionless_unscaled for na in range(naxes)) - - if self.axes_names is not None: - if isinstance(self.axes_names, str): - self.axes_names = (self.axes_names,) - else: - self.axes_names = tuple(self.axes_names) - if len(self.axes_names) != naxes: - msg = "Number of axes names does not match number of axes." - raise ValueError(msg) - else: - self.axes_names = tuple([""] * naxes) - - if self.axis_physical_types is not None: - if isinstance(self.axis_physical_types, str): - self.axis_physical_types = (self.axis_physical_types,) - elif not isiterable(self.axis_physical_types): - msg = ( - "axis_physical_types must be of type string or iterable of strings" - ) - raise TypeError(msg) - if len(self.axis_physical_types) != naxes: - msg = f'"axis_physical_types" must be of length {naxes}' - raise ValueError(msg) - ph_type = [] - for axt in self.axis_physical_types: - if axt not in VALID_UCDS and not axt.startswith("custom:"): - ph_type.append(f"custom:{axt}") - else: - ph_type.append(axt) - - validate_physical_types(ph_type) - self.axis_physical_types = tuple(ph_type) - - @property - def _default_axis_physical_types(self): - """ - The default physical types to use for this frame if none are specified - by the user. - """ - return tuple(f"custom:{t}" for t in self.axes_type) - - class BaseCoordinateFrame(abc.ABC): """ API Definition for a Coordinate frame diff --git a/gwcs/coordinate_frames/_properties.py b/gwcs/coordinate_frames/_properties.py new file mode 100644 index 00000000..84dbd15d --- /dev/null +++ b/gwcs/coordinate_frames/_properties.py @@ -0,0 +1,74 @@ +from dataclasses import InitVar, dataclass + +from astropy import units as u +from astropy import utils as astutil +from astropy.utils.misc import isiterable +from astropy.wcs.wcsapi.low_level_api import VALID_UCDS, validate_physical_types + + +@dataclass +class FrameProperties: + naxes: InitVar[int] + axes_type: tuple[str] + unit: tuple[u.Unit] = None + axes_names: tuple[str] = None + axis_physical_types: list[str] = None + + def __post_init__(self, naxes): + if isinstance(self.axes_type, str): + self.axes_type = (self.axes_type,) + else: + self.axes_type = tuple(self.axes_type) + + if len(self.axes_type) != naxes: + msg = "Length of axes_type does not match number of axes." + raise ValueError(msg) + + if self.unit is not None: + unit = tuple(self.unit) if astutil.isiterable(self.unit) else (self.unit,) + if len(unit) != naxes: + msg = "Number of units does not match number of axes." + raise ValueError(msg) + self.unit = tuple(u.Unit(au) for au in unit) + else: + self.unit = tuple(u.dimensionless_unscaled for na in range(naxes)) + + if self.axes_names is not None: + if isinstance(self.axes_names, str): + self.axes_names = (self.axes_names,) + else: + self.axes_names = tuple(self.axes_names) + if len(self.axes_names) != naxes: + msg = "Number of axes names does not match number of axes." + raise ValueError(msg) + else: + self.axes_names = tuple([""] * naxes) + + if self.axis_physical_types is not None: + if isinstance(self.axis_physical_types, str): + self.axis_physical_types = (self.axis_physical_types,) + elif not isiterable(self.axis_physical_types): + msg = ( + "axis_physical_types must be of type string or iterable of strings" + ) + raise TypeError(msg) + if len(self.axis_physical_types) != naxes: + msg = f'"axis_physical_types" must be of length {naxes}' + raise ValueError(msg) + ph_type = [] + for axt in self.axis_physical_types: + if axt not in VALID_UCDS and not axt.startswith("custom:"): + ph_type.append(f"custom:{axt}") + else: + ph_type.append(axt) + + validate_physical_types(ph_type) + self.axis_physical_types = tuple(ph_type) + + @property + def _default_axis_physical_types(self): + """ + The default physical types to use for this frame if none are specified + by the user. + """ + return tuple(f"custom:{t}" for t in self.axes_type) From fbe3200e42a95768415a1e9b22d13cfafa1e547c Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 18 Feb 2025 15:56:21 -0500 Subject: [PATCH 11/14] Type hint the FrameProperties object --- gwcs/_typing.py | 12 ++- gwcs/coordinate_frames/_axis.py | 18 +++++ gwcs/coordinate_frames/_properties.py | 110 +++++++++++++++----------- pyproject.toml | 5 +- 4 files changed, 94 insertions(+), 51 deletions(-) create mode 100644 gwcs/coordinate_frames/_axis.py diff --git a/gwcs/_typing.py b/gwcs/_typing.py index 2300d163..b1b4cf08 100644 --- a/gwcs/_typing.py +++ b/gwcs/_typing.py @@ -5,12 +5,19 @@ import numpy as np import numpy.typing as npt -from astropy.coordinates import SkyCoord, SpectralCoord, StokesCoord +from astropy.coordinates import ( + BaseCoordinateFrame, + SkyCoord, + SpectralCoord, + StokesCoord, +) from astropy.modeling.bounding_box import CompoundBoundingBox, ModelBoundingBox from astropy.time import Time from astropy.units import Quantity __all__ = [ + "AxisPhysicalType", + "AxisPhysicalTypes", "BoundingBox", "Bounds", "HighLevelObject", @@ -55,3 +62,6 @@ WorldAxisComponents: TypeAlias = list[WorldAxisComponent] WorldAxisClasses: TypeAlias = dict[str, WorldAxisClass] + +AxisPhysicalType: TypeAlias = str | BaseCoordinateFrame +AxisPhysicalTypes: TypeAlias = tuple[str | BaseCoordinateFrame, ...] diff --git a/gwcs/coordinate_frames/_axis.py b/gwcs/coordinate_frames/_axis.py new file mode 100644 index 00000000..2d74bf4f --- /dev/null +++ b/gwcs/coordinate_frames/_axis.py @@ -0,0 +1,18 @@ +from enum import StrEnum +from typing import TypeAlias + +__all__ = ["AxesType", "AxisType"] + + +class AxisType(StrEnum): + """ + Enumeration of the Axis types + """ + + SPATIAL = "SPATIAL" + SPECTRAL = "SPECTRAL" + TIME = "TIME" + STOKES = "STOKES" + + +AxesType: TypeAlias = tuple[AxisType, ...] | AxisType diff --git a/gwcs/coordinate_frames/_properties.py b/gwcs/coordinate_frames/_properties.py index 84dbd15d..9c12cdde 100644 --- a/gwcs/coordinate_frames/_properties.py +++ b/gwcs/coordinate_frames/_properties.py @@ -1,74 +1,88 @@ -from dataclasses import InitVar, dataclass +from __future__ import annotations -from astropy import units as u -from astropy import utils as astutil +from collections.abc import Callable + +from astropy.units import Unit, dimensionless_unscaled from astropy.utils.misc import isiterable from astropy.wcs.wcsapi.low_level_api import VALID_UCDS, validate_physical_types +from gwcs._typing import AxisPhysicalType, AxisPhysicalTypes + +from ._axis import AxesType + +__all__ = ["FrameProperties"] + -@dataclass class FrameProperties: - naxes: InitVar[int] - axes_type: tuple[str] - unit: tuple[u.Unit] = None - axes_names: tuple[str] = None - axis_physical_types: list[str] = None - - def __post_init__(self, naxes): - if isinstance(self.axes_type, str): - self.axes_type = (self.axes_type,) - else: - self.axes_type = tuple(self.axes_type) + def __init__( + self, + naxes: int, + axes_type: AxesType, + unit: tuple[Unit, ...] | None = None, + axes_names: str | tuple[str, ...] | None = None, + axis_physical_types: AxisPhysicalType | AxisPhysicalTypes | None = None, + default_axis_physical_types: Callable[[FrameProperties], AxisPhysicalTypes] + | None = None, + ) -> None: + self.naxes = naxes + + self.axes_type = ( + (axes_type,) if isinstance(axes_type, str) else tuple(axes_type) + ) if len(self.axes_type) != naxes: msg = "Length of axes_type does not match number of axes." raise ValueError(msg) - if self.unit is not None: - unit = tuple(self.unit) if astutil.isiterable(self.unit) else (self.unit,) - if len(unit) != naxes: + if unit is None: + self.unit = tuple(dimensionless_unscaled for _ in range(naxes)) + else: + unit_ = tuple(unit) if isiterable(unit) else (unit,) # type: ignore[no-untyped-call] + if len(unit_) != naxes: msg = "Number of units does not match number of axes." raise ValueError(msg) - self.unit = tuple(u.Unit(au) for au in unit) - else: - self.unit = tuple(u.dimensionless_unscaled for na in range(naxes)) + self.unit = tuple(Unit(au) for au in unit_) # type: ignore[no-untyped-call, misc] - if self.axes_names is not None: - if isinstance(self.axes_names, str): - self.axes_names = (self.axes_names,) - else: - self.axes_names = tuple(self.axes_names) + if axes_names is None: + self.axes_names = tuple([""] * naxes) + else: + self.axes_names = ( + (axes_names,) if isinstance(axes_names, str) else tuple(axes_names) + ) if len(self.axes_names) != naxes: msg = "Number of axes names does not match number of axes." raise ValueError(msg) + + if axis_physical_types is None: + if default_axis_physical_types is None: + default_axis_physical_types = self._default_axis_physical_types + + self.axis_physical_types: AxisPhysicalTypes = default_axis_physical_types( + self + ) else: - self.axes_names = tuple([""] * naxes) + self.axis_physical_types = ( + (axis_physical_types,) + if isinstance(axis_physical_types, str) + else tuple(axis_physical_types) + ) - if self.axis_physical_types is not None: - if isinstance(self.axis_physical_types, str): - self.axis_physical_types = (self.axis_physical_types,) - elif not isiterable(self.axis_physical_types): - msg = ( - "axis_physical_types must be of type string or iterable of strings" - ) - raise TypeError(msg) if len(self.axis_physical_types) != naxes: msg = f'"axis_physical_types" must be of length {naxes}' raise ValueError(msg) - ph_type = [] - for axt in self.axis_physical_types: - if axt not in VALID_UCDS and not axt.startswith("custom:"): - ph_type.append(f"custom:{axt}") - else: - ph_type.append(axt) - - validate_physical_types(ph_type) - self.axis_physical_types = tuple(ph_type) - - @property - def _default_axis_physical_types(self): + + self.axis_physical_types = tuple( + f"custom:{axt}" + if axt not in VALID_UCDS and not axt.startswith("custom:") + else axt + for axt in self.axis_physical_types + ) + validate_physical_types(self.axis_physical_types) # type: ignore[no-untyped-call] + + @staticmethod + def _default_axis_physical_types(properties: FrameProperties) -> AxisPhysicalTypes: """ The default physical types to use for this frame if none are specified by the user. """ - return tuple(f"custom:{t}" for t in self.axes_type) + return tuple(f"custom:{t}" for t in properties.axes_type) diff --git a/pyproject.toml b/pyproject.toml index 859b7164..93cc57f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -209,7 +209,8 @@ module = [ # "astropy.table.*", "astropy.time.*", "astropy.units.*", - # "astropy.utils.*", + "astropy.utils.*", + "astropy.wcs.*", ] follow_untyped_imports = true @@ -220,7 +221,7 @@ module = [ "gwcs.converters.*", "gwcs.wcs.*", "gwcs.api.*", - "gwcs.coordinate_frames.*", + "gwcs.coordinate_frames._coordinate_frames.*", "gwcs.examples.*", "gwcs.extension.*", "gwcs.geometry.*", From f24528530ddea0253b46d8487c2ee08bf2308806 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 18 Feb 2025 16:20:20 -0500 Subject: [PATCH 12/14] Refactor BaseCoordinateFrame into its own module --- gwcs/coordinate_frames/__init__.py | 2 +- gwcs/coordinate_frames/_base.py | 168 +++++++++++++++++++ gwcs/coordinate_frames/_coordinate_frames.py | 143 +--------------- 3 files changed, 170 insertions(+), 143 deletions(-) create mode 100644 gwcs/coordinate_frames/_base.py diff --git a/gwcs/coordinate_frames/__init__.py b/gwcs/coordinate_frames/__init__.py index a6cde3a6..553d3948 100644 --- a/gwcs/coordinate_frames/__init__.py +++ b/gwcs/coordinate_frames/__init__.py @@ -116,8 +116,8 @@ """ +from ._base import BaseCoordinateFrame from ._coordinate_frames import ( - BaseCoordinateFrame, CelestialFrame, CompositeFrame, CoordinateFrame, diff --git a/gwcs/coordinate_frames/_base.py b/gwcs/coordinate_frames/_base.py new file mode 100644 index 00000000..d1fa3cde --- /dev/null +++ b/gwcs/coordinate_frames/_base.py @@ -0,0 +1,168 @@ +import abc +from typing import cast + +import numpy as np +import numpy.typing as npt +from astropy import units as u +from astropy.coordinates import BaseCoordinateFrame as _BaseCoordinateFrame + +from gwcs._typing import ( + AxisPhysicalTypes, + LowLevelUnitValue, + WorldAxisClasses, + WorldAxisComponents, +) + +from ._axis import AxesType +from ._properties import FrameProperties + +__all__ = ["BaseCoordinateFrame"] + + +class BaseCoordinateFrame(abc.ABC): + """ + API Definition for a Coordinate frame + """ + + _prop: FrameProperties + """ + The FrameProperties object holding properties in native frame order. + """ + + @property + @abc.abstractmethod + def naxes(self) -> int: + """ + The number of axes described by this frame. + """ + + @property + @abc.abstractmethod + def name(self) -> str: + """ + The name of the coordinate frame. + """ + + @property + @abc.abstractmethod + def unit(self) -> tuple[u.Unit, ...]: + """ + The units of the axes in this frame. + """ + + @property + @abc.abstractmethod + def axes_names(self) -> tuple[str, ...]: + """ + Names describing the axes of the frame. + """ + + @property + @abc.abstractmethod + def axes_order(self) -> tuple[int, ...]: + """ + The position of the axes in the frame in the transform. + """ + + @property + @abc.abstractmethod + def reference_frame(self) -> _BaseCoordinateFrame: + """ + The reference frame of the coordinates described by this frame. + + This is usually an Astropy object such as ``SkyCoord`` or ``Time``. + """ + + @property + @abc.abstractmethod + def axes_type(self) -> AxesType: + """ + An upcase string describing the type of the axis. + + Known values are ``"SPATIAL", "TEMPORAL", "STOKES", "SPECTRAL", "PIXEL"``. + """ + + @property + @abc.abstractmethod + def axis_physical_types(self) -> AxisPhysicalTypes: + """ + The UCD 1+ physical types for the axes, in frame order. + """ + + @property + @abc.abstractmethod + def world_axis_object_classes(self) -> WorldAxisClasses: + """ + The APE 14 object classes for this frame. + + See Also + -------- + astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_classes + """ + + @property + def world_axis_object_components(self) -> WorldAxisComponents: + """ + The APE 14 object components for this frame. + + See Also + -------- + astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components + """ + if self.naxes == 1: + return self._native_world_axis_object_components + + # If we have more than one axis then we should sort the native + # components by the axes_order. + ordered = np.array(self._native_world_axis_object_components, dtype=object)[ + np.argsort(self.axes_order) + ] + return list(map(tuple, ordered)) + + @property + @abc.abstractmethod + def _native_world_axis_object_components(self) -> WorldAxisComponents: + """ + This property holds the "native" frame order of the components. + + The native order of the components is the order the frame assumes the + axes are in when creating the high level objects, for example + ``CelestialFrame`` creates ``SkyCoord`` objects which are in lon, lat + order (in their positional args). + + This property is used both to construct the ordered + ``world_axis_object_components`` property as well as by `CompositeFrame` + to be able to get the components in their native order. + """ + + def add_units( + self, arrays: tuple[LowLevelUnitValue, ...] + ) -> tuple[u.Quantity, ...]: + """ + Add units to the arrays + """ + return tuple( + u.Quantity(array, unit=unit) # type: ignore[arg-type] + for array, unit in zip(arrays, self.unit, strict=True) + ) + + def remove_units( + self, arrays: tuple[LowLevelUnitValue, ...] | LowLevelUnitValue + ) -> tuple[npt.NDArray[np.number], ...]: + """ + Remove units from the input arrays + """ + if self.naxes == 1: + arrays = (cast(LowLevelUnitValue, arrays),) + + return tuple( + cast( + npt.NDArray[np.number], + array.to_value(unit) # type: ignore[no-untyped-call] + if isinstance(array, u.Quantity) + else array, + ) + for array, unit in zip( + cast(tuple[LowLevelUnitValue, ...], arrays), self.unit, strict=True + ) + ) diff --git a/gwcs/coordinate_frames/_coordinate_frames.py b/gwcs/coordinate_frames/_coordinate_frames.py index 45aa973f..79048f4c 100644 --- a/gwcs/coordinate_frames/_coordinate_frames.py +++ b/gwcs/coordinate_frames/_coordinate_frames.py @@ -1,6 +1,5 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst -import abc import contextlib import logging import numbers @@ -18,10 +17,10 @@ values_to_high_level_objects, ) +from ._base import BaseCoordinateFrame from ._properties import FrameProperties __all__ = [ - "BaseCoordinateFrame", "CelestialFrame", "CompositeFrame", "CoordinateFrame", @@ -87,146 +86,6 @@ def get_ctype_from_ucd(ucd): return UCD1_TO_CTYPE.get(ucd, "") -class BaseCoordinateFrame(abc.ABC): - """ - API Definition for a Coordinate frame - """ - - _prop: FrameProperties - """ - The FrameProperties object holding properties in native frame order. - """ - - @property - @abc.abstractmethod - def naxes(self) -> int: - """ - The number of axes described by this frame. - """ - - @property - @abc.abstractmethod - def name(self) -> str: - """ - The name of the coordinate frame. - """ - - @property - @abc.abstractmethod - def unit(self) -> tuple[u.Unit, ...]: - """ - The units of the axes in this frame. - """ - - @property - @abc.abstractmethod - def axes_names(self) -> tuple[str, ...]: - """ - Names describing the axes of the frame. - """ - - @property - @abc.abstractmethod - def axes_order(self) -> tuple[int, ...]: - """ - The position of the axes in the frame in the transform. - """ - - @property - @abc.abstractmethod - def reference_frame(self): - """ - The reference frame of the coordinates described by this frame. - - This is usually an Astropy object such as ``SkyCoord`` or ``Time``. - """ - - @property - @abc.abstractmethod - def axes_type(self): - """ - An upcase string describing the type of the axis. - - Known values are ``"SPATIAL", "TEMPORAL", "STOKES", "SPECTRAL", "PIXEL"``. - """ - - @property - @abc.abstractmethod - def axis_physical_types(self): - """ - The UCD 1+ physical types for the axes, in frame order. - """ - - @property - @abc.abstractmethod - def world_axis_object_classes(self): - """ - The APE 14 object classes for this frame. - - See Also - -------- - astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_classes - """ - - @property - def world_axis_object_components(self): - """ - The APE 14 object components for this frame. - - See Also - -------- - astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components - """ - if self.naxes == 1: - return self._native_world_axis_object_components - - # If we have more than one axis then we should sort the native - # components by the axes_order. - ordered = np.array(self._native_world_axis_object_components, dtype=object)[ - np.argsort(self.axes_order) - ] - return list(map(tuple, ordered)) - - @property - @abc.abstractmethod - def _native_world_axis_object_components(self): - """ - This property holds the "native" frame order of the components. - - The native order of the components is the order the frame assumes the - axes are in when creating the high level objects, for example - ``CelestialFrame`` creates ``SkyCoord`` objects which are in lon, lat - order (in their positional args). - - This property is used both to construct the ordered - ``world_axis_object_components`` property as well as by `CompositeFrame` - to be able to get the components in their native order. - """ - - def add_units(self, arrays: u.Quantity | np.ndarray | float) -> tuple[u.Quantity]: - """ - Add units to the arrays - """ - return tuple( - u.Quantity(array, unit=unit) - for array, unit in zip(arrays, self.unit, strict=True) - ) - - def remove_units( - self, arrays: u.Quantity | np.ndarray | float - ) -> tuple[np.ndarray]: - """ - Remove units from the input arrays - """ - if self.naxes == 1: - arrays = (arrays,) - - return tuple( - array.to_value(unit) if isinstance(array, u.Quantity) else array - for array, unit in zip(arrays, self.unit, strict=True) - ) - - class CoordinateFrame(BaseCoordinateFrame): """ Base class for Coordinate Frames. From 426187b5075326ed5e18e7525554fc9c7146a56b Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 18 Feb 2025 17:41:01 -0500 Subject: [PATCH 13/14] Refactor CoordinateFrame into its own module (_core) --- gwcs/_typing.py | 15 +- gwcs/api.py | 36 +- gwcs/coordinate_frames/__init__.py | 2 +- gwcs/coordinate_frames/_axis.py | 6 + gwcs/coordinate_frames/_base.py | 17 +- gwcs/coordinate_frames/_coordinate_frames.py | 328 +++---------------- gwcs/coordinate_frames/_core.py | 262 +++++++++++++++ gwcs/tests/test_coordinate_systems.py | 6 +- 8 files changed, 367 insertions(+), 305 deletions(-) create mode 100644 gwcs/coordinate_frames/_core.py diff --git a/gwcs/_typing.py b/gwcs/_typing.py index b1b4cf08..6aed2a9c 100644 --- a/gwcs/_typing.py +++ b/gwcs/_typing.py @@ -21,6 +21,7 @@ "BoundingBox", "Bounds", "HighLevelObject", + "HighLevelObjects", "Interval", "LowLevelArrays", "LowLevelUnitArrays", @@ -28,9 +29,6 @@ "LowLevelValue", "OutputLowLevelArray", "Real", - "WorldAxisClass", - "WorldAxisComponent", - "WorldAxisComponents", ] Real: TypeAlias = int | float | Fraction | np.integer | np.floating @@ -46,7 +44,7 @@ LowLevelUnitValue: TypeAlias = LowLevelValue | Quantity # This is to represent all the values together for a single low-level function. -LowLevelArrays: TypeAlias = tuple[LowLevelValue, ...] +LowLevelArrays: TypeAlias = tuple[LowLevelValue, ...] | LowLevelValue LowLevelUnitArrays: TypeAlias = tuple[LowLevelUnitValue, ...] # This is to represent a general array output from a low-level function. @@ -54,14 +52,7 @@ OutputLowLevelArray: TypeAlias = LowLevelValue | LowLevelArrays HighLevelObject: TypeAlias = Time | SkyCoord | SpectralCoord | StokesCoord | Quantity - -WorldAxisComponent: TypeAlias = tuple[str, str | int, str] -WorldAxisClass: TypeAlias = tuple[ - type | str, tuple[int | None, ...], dict[str, HighLevelObject] -] - -WorldAxisComponents: TypeAlias = list[WorldAxisComponent] -WorldAxisClasses: TypeAlias = dict[str, WorldAxisClass] +HighLevelObjects: TypeAlias = tuple[HighLevelObject, ...] | HighLevelObject AxisPhysicalType: TypeAlias = str | BaseCoordinateFrame AxisPhysicalTypes: TypeAlias = tuple[str | BaseCoordinateFrame, ...] diff --git a/gwcs/api.py b/gwcs/api.py index 4a563eda..c00c1c9e 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -5,13 +5,47 @@ """ +from typing import Any, NamedTuple, TypeAlias + import astropy.units as u from astropy.modeling import separable from astropy.wcs.wcsapi import BaseLowLevelWCS, HighLevelWCSMixin from gwcs import utils -__all__ = ["GWCSAPIMixin"] +__all__ = [ + "GWCSAPIMixin", + "WorldAxisClass", + "WorldAxisClasses", + "WorldAxisComponent", + "WorldAxisComponents", +] + + +class WorldAxisClass(NamedTuple): + """ + Named tuple for the world_axis_object_classes WCS property + """ + + object_type: type | str + args: tuple[int | None, ...] + kwargs: dict[str, Any] + + +WorldAxisClasses: TypeAlias = dict[str | int, WorldAxisClass] + + +class WorldAxisComponent(NamedTuple): + """ + Named tuple for the world_axis_object_components WCS property + """ + + name: str + key: str | int + property_name: str + + +WorldAxisComponents: TypeAlias = list[WorldAxisComponent] class GWCSAPIMixin(BaseLowLevelWCS, HighLevelWCSMixin): diff --git a/gwcs/coordinate_frames/__init__.py b/gwcs/coordinate_frames/__init__.py index 553d3948..38654af7 100644 --- a/gwcs/coordinate_frames/__init__.py +++ b/gwcs/coordinate_frames/__init__.py @@ -120,7 +120,6 @@ from ._coordinate_frames import ( CelestialFrame, CompositeFrame, - CoordinateFrame, EmptyFrame, Frame2D, SpectralFrame, @@ -128,6 +127,7 @@ TemporalFrame, get_ctype_from_ucd, ) +from ._core import CoordinateFrame __all__ = [ "BaseCoordinateFrame", diff --git a/gwcs/coordinate_frames/_axis.py b/gwcs/coordinate_frames/_axis.py index 2d74bf4f..f05b1c94 100644 --- a/gwcs/coordinate_frames/_axis.py +++ b/gwcs/coordinate_frames/_axis.py @@ -14,5 +14,11 @@ class AxisType(StrEnum): TIME = "TIME" STOKES = "STOKES" + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return self.value + AxesType: TypeAlias = tuple[AxisType, ...] | AxisType diff --git a/gwcs/coordinate_frames/_base.py b/gwcs/coordinate_frames/_base.py index d1fa3cde..6120a42d 100644 --- a/gwcs/coordinate_frames/_base.py +++ b/gwcs/coordinate_frames/_base.py @@ -1,5 +1,6 @@ import abc -from typing import cast +from collections.abc import Sequence +from typing import Any, cast import numpy as np import numpy.typing as npt @@ -9,9 +10,8 @@ from gwcs._typing import ( AxisPhysicalTypes, LowLevelUnitValue, - WorldAxisClasses, - WorldAxisComponents, ) +from gwcs.api import WorldAxisClasses, WorldAxisComponent, WorldAxisComponents from ._axis import AxesType from ._properties import FrameProperties @@ -66,7 +66,7 @@ def axes_order(self) -> tuple[int, ...]: @property @abc.abstractmethod - def reference_frame(self) -> _BaseCoordinateFrame: + def reference_frame(self) -> _BaseCoordinateFrame | None: """ The reference frame of the coordinates described by this frame. @@ -116,8 +116,13 @@ def world_axis_object_components(self) -> WorldAxisComponents: # components by the axes_order. ordered = np.array(self._native_world_axis_object_components, dtype=object)[ np.argsort(self.axes_order) - ] - return list(map(tuple, ordered)) + ].tolist() + + # Makes MyPy happy + def _func(arg: Sequence[Any]) -> WorldAxisComponent: + return WorldAxisComponent(*arg) + + return list(map(_func, ordered)) @property @abc.abstractmethod diff --git a/gwcs/coordinate_frames/_coordinate_frames.py b/gwcs/coordinate_frames/_coordinate_frames.py index 79048f4c..f98ed26d 100644 --- a/gwcs/coordinate_frames/_coordinate_frames.py +++ b/gwcs/coordinate_frames/_coordinate_frames.py @@ -2,7 +2,6 @@ import contextlib import logging -import numbers from collections import defaultdict import numpy as np @@ -10,20 +9,13 @@ from astropy import time from astropy import units as u from astropy.coordinates import StokesCoord -from astropy.utils.misc import isiterable from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 -from astropy.wcs.wcsapi.high_level_api import ( - high_level_objects_to_values, - values_to_high_level_objects, -) -from ._base import BaseCoordinateFrame -from ._properties import FrameProperties +from ._core import CoordinateFrame __all__ = [ "CelestialFrame", "CompositeFrame", - "CoordinateFrame", "EmptyFrame", "Frame2D", "SpectralFrame", @@ -86,226 +78,6 @@ def get_ctype_from_ucd(ucd): return UCD1_TO_CTYPE.get(ucd, "") -class CoordinateFrame(BaseCoordinateFrame): - """ - Base class for Coordinate Frames. - - Parameters - ---------- - naxes : int - Number of axes. - axes_type : str - One of ["SPATIAL", "SPECTRAL", "TIME"] - axes_order : tuple of int - A dimension in the input data that corresponds to this axis. - reference_frame : astropy.coordinates.builtin_frames - Reference frame (usually used with output_frame to convert to world - coordinate objects). - unit : list of astropy.units.Unit - Unit for each axis. - axes_names : list - Names of the axes in this frame. - name : str - Name of this frame. - """ - - def __init__( - self, - naxes, - axes_type, - axes_order, - reference_frame=None, - unit=None, - axes_names=None, - name=None, - axis_physical_types=None, - ): - self._naxes = naxes - self._axes_order = tuple(axes_order) - self._reference_frame = reference_frame - - if name is None: - self._name = self.__class__.__name__ - else: - self._name = name - - if len(self._axes_order) != naxes: - msg = "Length of axes_order does not match number of axes." - raise ValueError(msg) - - if isinstance(axes_type, str): - axes_type = (axes_type,) - - self._prop = FrameProperties( - naxes, - axes_type, - unit, - axes_names, - axis_physical_types or self._default_axis_physical_types(axes_type), - ) - - super().__init__() - - def _default_axis_physical_types(self, axes_type): - """ - The default physical types to use for this frame if none are specified - by the user. - """ - return tuple(f"custom:{t}" for t in axes_type) - - def __repr__(self): - fmt = ( - f'<{self.__class__.__name__}(name="{self.name}", unit={self.unit}, ' - f"axes_names={self.axes_names}, axes_order={self.axes_order}" - ) - if self.reference_frame is not None: - fmt += f", reference_frame={self.reference_frame}" - fmt += ")>" - return fmt - - def __str__(self): - if self._name is not None: - return self._name - return self.__class__.__name__ - - def _sort_property(self, prop): - sorted_prop = sorted( - zip(prop, self.axes_order, strict=False), key=lambda x: x[1] - ) - return tuple([t[0] for t in sorted_prop]) - - @property - def name(self): - """A custom name of this frame.""" - return self._name - - @name.setter - def name(self, val): - """A custom name of this frame.""" - self._name = val - - @property - def naxes(self): - """The number of axes in this frame.""" - return self._naxes - - @property - def unit(self): - """The unit of this frame.""" - return self._sort_property(self._prop.unit) - - @property - def axes_names(self): - """Names of axes in the frame.""" - return self._sort_property(self._prop.axes_names) - - @property - def axes_order(self): - """A tuple of indices which map inputs to axes.""" - return self._axes_order - - @property - def reference_frame(self): - """Reference frame, used to convert to world coordinate objects.""" - return self._reference_frame - - @property - def axes_type(self): - """Type of this frame : 'SPATIAL', 'SPECTRAL', 'TIME'.""" - return self._sort_property(self._prop.axes_type) - - @property - def axis_physical_types(self): - """ - The axis physical types for this frame. - - These physical types are the types in frame order, not transform order. - """ - return self._sort_property(self._prop.axis_physical_types) - - @property - def world_axis_object_classes(self): - return { - f"{at}{i}" if i != 0 else at: (u.Quantity, (), {"unit": unit}) - for i, (at, unit) in enumerate(zip(self.axes_type, self.unit, strict=False)) - } - - @property - def _native_world_axis_object_components(self): - return [ - (f"{at}{i}" if i != 0 else at, 0, "value") - for i, at in enumerate(self._prop.axes_type) - ] - - @property - def serialized_classes(self): - """ - This property is used by the low level WCS API in Astropy. - - By providing it we can duck type as a low level WCS object. - """ - return False - - def to_high_level_coordinates(self, *values): - """ - Convert "values" to high level coordinate objects described by this frame. - - "values" are the coordinates in array or scalar form, and high level - objects are things such as ``SkyCoord`` or ``Quantity``. See - :ref:`wcsapi` for details. - - Parameters - ---------- - values : `numbers.Number`, `numpy.ndarray`, or `~astropy.units.Quantity` - ``naxis`` number of coordinates as scalars or arrays. - - Returns - ------- - high_level_coordinates - One (or more) high level object describing the coordinate. - """ - # We allow Quantity-like objects here which values_to_high_level_objects - # does not. - values = [ - v.to_value(unit) if hasattr(v, "to_value") else v - for v, unit in zip(values, self.unit, strict=False) - ] - - if not all( - isinstance(v, numbers.Number) or type(v) is np.ndarray for v in values - ): - msg = "All values should be a scalar number or a numpy array." - raise TypeError(msg) - - high_level = values_to_high_level_objects(*values, low_level_wcs=self) - if len(high_level) == 1: - high_level = high_level[0] - return high_level - - def from_high_level_coordinates(self, *high_level_coords): - """ - Convert high level coordinate objects to "values" as described by this frame. - - "values" are the coordinates in array or scalar form, and high level - objects are things such as ``SkyCoord`` or ``Quantity``. See - :ref:`wcsapi` for details. - - Parameters - ---------- - high_level_coordinates - One (or more) high level object describing the coordinate. - - Returns - ------- - values : `numbers.Number` or `numpy.ndarray` - ``naxis`` number of coordinates as scalars or arrays. - """ - values = high_level_objects_to_values(*high_level_coords, low_level_wcs=self) - if len(values) == 1: - values = values[0] - return values - - class EmptyFrame(CoordinateFrame): """ Represents a "default" detector frame. This is for use as the default value @@ -420,11 +192,15 @@ def __init__( and not isinstance(reference_frame, str) and reference_frame.name.upper() in STANDARD_REFERENCE_FRAMES ): - _axes_names = list(reference_frame.representation_component_names.values()) - if "distance" in _axes_names: - _axes_names.remove("distance") - if axes_names is None: - axes_names = _axes_names + _axes_names = ( + tuple( + n + for n in reference_frame.representation_component_names.values() + if n != "distance" + ) + if axes_names is None + else axes_names + ) naxes = len(_axes_names) self.native_axes_order = tuple(range(naxes)) @@ -434,33 +210,30 @@ def __init__( unit = tuple([u.degree] * naxes) axes_type = ["SPATIAL"] * naxes - pht = axis_physical_types or self._default_axis_physical_types( - reference_frame, axes_names - ) super().__init__( naxes=naxes, axes_type=axes_type, axes_order=axes_order, reference_frame=reference_frame, unit=unit, - axes_names=axes_names, + axes_names=_axes_names, name=name, - axis_physical_types=pht, + axis_physical_types=axis_physical_types, ) - def _default_axis_physical_types(self, reference_frame, axes_names): - if isinstance(reference_frame, coord.Galactic): + def _default_axis_physical_types(self, properties): + if isinstance(self.reference_frame, coord.Galactic): return "pos.galactic.lon", "pos.galactic.lat" if isinstance( - reference_frame, + self.reference_frame, coord.GeocentricTrueEcliptic | coord.GCRS | coord.PrecessedGeocentric, ): return "pos.bodyrc.lon", "pos.bodyrc.lat" - if isinstance(reference_frame, coord.builtin_frames.BaseRADecFrame): + if isinstance(self.reference_frame, coord.builtin_frames.BaseRADecFrame): return "pos.eq.ra", "pos.eq.dec" - if isinstance(reference_frame, coord.builtin_frames.BaseEclipticFrame): + if isinstance(self.reference_frame, coord.builtin_frames.BaseEclipticFrame): return "pos.ecliptic.lon", "pos.ecliptic.lat" - return tuple(f"custom:{t}" for t in axes_names) + return tuple(f"custom:{t}" for t in properties.axes_names) @property def world_axis_object_classes(self): @@ -509,11 +282,6 @@ def __init__( name=None, axis_physical_types=None, ): - if not isiterable(unit): - unit = (unit,) - unit = [u.Unit(un) for un in unit] - pht = axis_physical_types or self._default_axis_physical_types(unit) - super().__init__( naxes=1, axes_type="SPECTRAL", @@ -522,17 +290,17 @@ def __init__( reference_frame=reference_frame, unit=unit, name=name, - axis_physical_types=pht, + axis_physical_types=axis_physical_types, ) - def _default_axis_physical_types(self, unit): - if unit[0].physical_type == "frequency": + def _default_axis_physical_types(self, properties): + if properties.unit[0].physical_type == "frequency": return ("em.freq",) - if unit[0].physical_type == "length": + if properties.unit[0].physical_type == "length": return ("em.wl",) - if unit[0].physical_type == "energy": + if properties.unit[0].physical_type == "energy": return ("em.energy",) - if unit[0].physical_type == "speed": + if properties.unit[0].physical_type == "speed": return ("spect.dopplerVeloc",) logging.warning( "Physical type may be ambiguous. Consider " @@ -540,7 +308,7 @@ def _default_axis_physical_types(self, unit): "either 'spect.dopplerVeloc.optical' or " "'spect.dopplerVeloc.radio'." ) - return (f"custom:{unit[0].physical_type}",) + return (f"custom:{properties.unit[0].physical_type}",) @property def world_axis_object_classes(self): @@ -581,30 +349,31 @@ def __init__( name=None, axis_physical_types=None, ): - axes_names = ( - axes_names - or f"{reference_frame.format}({reference_frame.scale}; " - f"{reference_frame.location}" + _axes_names = ( + ( + f"{reference_frame.format}({reference_frame.scale}; " + f"{reference_frame.location}", + ) + if axes_names is None + else axes_names ) - pht = axis_physical_types or self._default_axis_physical_types() - super().__init__( naxes=1, axes_type="TIME", axes_order=axes_order, - axes_names=axes_names, + axes_names=_axes_names, reference_frame=reference_frame, unit=unit, name=name, - axis_physical_types=pht, + axis_physical_types=axis_physical_types, ) self._attrs = {} for a in self.reference_frame.info._represent_as_dict_extra_attrs: with contextlib.suppress(AttributeError): self._attrs[a] = getattr(self.reference_frame, a) - def _default_axis_physical_types(self): + def _default_axis_physical_types(self, properties): return ("time",) def _convert_to_time(self, dt, *, unit, **kwargs): @@ -682,10 +451,10 @@ def __init__(self, frames, name=None): super().__init__( naxes, - axes_type=axes_type, - axes_order=axes_order, - unit=unit, - axes_names=axes_names, + axes_type=tuple(axes_type), + axes_order=tuple(axes_order), + unit=tuple(unit), + axes_names=tuple(axes_names), axis_physical_types=tuple(ph_type), name=name, ) @@ -774,8 +543,6 @@ def __init__( name=None, axis_physical_types=None, ): - pht = axis_physical_types or self._default_axis_physical_types() - super().__init__( 1, ["STOKES"], @@ -783,10 +550,10 @@ def __init__( name=name, axes_names=axes_names, unit=u.one, - axis_physical_types=pht, + axis_physical_types=axis_physical_types, ) - def _default_axis_physical_types(self): + def _default_axis_physical_types(self, properties): return ("phys.polarization.stokes",) @property @@ -831,9 +598,6 @@ def __init__( ): if axes_type is None: axes_type = ["SPATIAL", "SPATIAL"] - pht = axis_physical_types or self._default_axis_physical_types( - axes_names, axes_type - ) super().__init__( naxes=2, @@ -842,13 +606,13 @@ def __init__( name=name, axes_names=axes_names, unit=unit, - axis_physical_types=pht, + axis_physical_types=axis_physical_types, ) - def _default_axis_physical_types(self, axes_names, axes_type): - if axes_names is not None and all(axes_names): - ph_type = axes_names + def _default_axis_physical_types(self, properties): + if properties.axes_names is not None and all(properties.axes_names): + ph_type = properties.axes_names else: - ph_type = axes_type + ph_type = properties.axes_type return tuple(f"custom:{t}" for t in ph_type) diff --git a/gwcs/coordinate_frames/_core.py b/gwcs/coordinate_frames/_core.py new file mode 100644 index 00000000..85b1984c --- /dev/null +++ b/gwcs/coordinate_frames/_core.py @@ -0,0 +1,262 @@ +import numbers +from typing import TypeVar + +import numpy as np +from astropy import units as u +from astropy.coordinates import BaseCoordinateFrame as _BaseCoordinateFrame +from astropy.wcs.wcsapi.high_level_api import ( + high_level_objects_to_values, + values_to_high_level_objects, +) + +from gwcs._typing import ( + AxisPhysicalTypes, + HighLevelObject, + HighLevelObjects, + LowLevelArrays, + LowLevelValue, +) +from gwcs.api import ( + WorldAxisClass, + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, +) + +from ._axis import AxesType +from ._base import BaseCoordinateFrame +from ._properties import FrameProperties + +__all__ = ["CoordinateFrame"] + +_T = TypeVar("_T") + + +class CoordinateFrame(BaseCoordinateFrame): + """ + Base class for Coordinate Frames. + + Parameters + ---------- + naxes + Number of axes. + axes_type + One of ["SPATIAL", "SPECTRAL", "TIME"] + axes_order + A dimension in the input data that corresponds to this axis. + reference_frame + Reference frame (usually used with output_frame to convert to world + coordinate objects). + unit + Unit for each axis. + axes_names + Names of the axes in this frame. + name + Name of this frame. + """ + + def __init__( + self, + naxes: int, + axes_type: AxesType, + axes_order: tuple[int, ...], + reference_frame: _BaseCoordinateFrame | None = None, + unit: tuple[u.Unit, ...] | None = None, + axes_names: tuple[str, ...] | None = None, + name: str | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + self._naxes = naxes + self._axes_order = tuple(axes_order) + self._reference_frame = reference_frame + + if name is None: + self._name = type(self).__name__ + else: + self._name = name + + if len(self._axes_order) != naxes: + msg = "Length of axes_order does not match number of axes." + raise ValueError(msg) + + if isinstance(axes_type, str): + axes_type = (axes_type,) + + self._prop = FrameProperties( + naxes, + axes_type, + unit, + axes_names, + axis_physical_types, + self._default_axis_physical_types, + ) + + super().__init__() + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + """ + The default physical types to use for this frame if none are specified + by the user. + """ + return tuple(f"custom:{t}" for t in properties.axes_type) + + def __repr__(self) -> str: + fmt = ( + f'<{self.__class__.__name__}(name="{self.name}", unit={self.unit}, ' + f"axes_names={self.axes_names}, axes_order={self.axes_order}" + ) + if self.reference_frame is not None: + fmt += f", reference_frame={self.reference_frame}" + fmt += ")>" + return fmt + + def __str__(self) -> str: + return self._name + + def _sort_property(self, prop: tuple[_T, ...]) -> tuple[_T, ...]: + sorted_prop = sorted( + zip(prop, self.axes_order, strict=False), key=lambda x: x[1] + ) + return tuple([t[0] for t in sorted_prop]) + + @property + def name(self) -> str: + """A custom name of this frame.""" + return self._name + + @name.setter + def name(self, val: str) -> None: + """A custom name of this frame.""" + self._name = val + + @property + def naxes(self) -> int: + """The number of axes in this frame.""" + return self._naxes + + @property + def unit(self) -> tuple[u.Unit, ...]: + """The unit of this frame.""" + return self._sort_property(self._prop.unit) # type: ignore[arg-type] + + @property + def axes_names(self) -> tuple[str, ...]: + """Names of axes in the frame.""" + return self._sort_property(self._prop.axes_names) + + @property + def axes_order(self) -> tuple[int, ...]: + """A tuple of indices which map inputs to axes.""" + return self._axes_order + + @property + def reference_frame(self) -> _BaseCoordinateFrame | None: + """Reference frame, used to convert to world coordinate objects.""" + return self._reference_frame + + @property + def axes_type(self) -> AxesType: + """Type of this frame : 'SPATIAL', 'SPECTRAL', 'TIME'.""" + return self._sort_property(self._prop.axes_type) + + @property + def axis_physical_types(self) -> AxisPhysicalTypes: + """ + The axis physical types for this frame. + + These physical types are the types in frame order, not transform order. + """ + return self._sort_property(self._prop.axis_physical_types) + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return { + f"{at}{i}" if i != 0 else str(at): WorldAxisClass( + u.Quantity, (), {"unit": unit} + ) + for i, (at, unit) in enumerate(zip(self.axes_type, self.unit, strict=False)) + } + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: + return [ + WorldAxisComponent(f"{at}{i}" if i != 0 else at, 0, "value") + for i, at in enumerate(self._prop.axes_type) + ] + + @property + def serialized_classes(self) -> bool: + """ + This property is used by the low level WCS API in Astropy. + + By providing it we can duck type as a low level WCS object. + """ + return False + + def to_high_level_coordinates(self, *values: LowLevelValue) -> HighLevelObjects: + """ + Convert "values" to high level coordinate objects described by this frame. + + "values" are the coordinates in array or scalar form, and high level + objects are things such as ``SkyCoord`` or ``Quantity``. See + :ref:`wcsapi` for details. + + Parameters + ---------- + values : `numbers.Number`, `numpy.ndarray`, or `~astropy.units.Quantity` + ``naxis`` number of coordinates as scalars or arrays. + + Returns + ------- + high_level_coordinates + One (or more) high level object describing the coordinate. + """ + # We allow Quantity-like objects here which values_to_high_level_objects + # does not. + values = tuple( + v.to_value(unit) if hasattr(v, "to_value") else v + for v, unit in zip(values, self.unit, strict=False) + ) + + if not all( + isinstance(v, numbers.Number) or type(v) is np.ndarray for v in values + ): + msg = "All values should be a scalar number or a numpy array." + raise TypeError(msg) + + high_level: list[HighLevelObject] = values_to_high_level_objects( + *values, low_level_wcs=self + ) # type: ignore[no-untyped-call] + if len(high_level) == 1: + return high_level[0] + + return tuple(high_level) + + def from_high_level_coordinates( + self, *high_level_coords: HighLevelObject + ) -> LowLevelArrays: + """ + Convert high level coordinate objects to "values" as described by this frame. + + "values" are the coordinates in array or scalar form, and high level + objects are things such as ``SkyCoord`` or ``Quantity``. See + :ref:`wcsapi` for details. + + Parameters + ---------- + high_level_coordinates + One (or more) high level object describing the coordinate. + + Returns + ------- + values : `numbers.Number` or `numpy.ndarray` + ``naxis`` number of coordinates as scalars or arrays. + """ + values: list[LowLevelValue] = high_level_objects_to_values( + *high_level_coords, low_level_wcs=self + ) # type: ignore[no-untyped-call] + if len(values) == 1: + return values[0] + return tuple(values) diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index ed8e0ee5..cfcad311 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -106,9 +106,9 @@ def coordinates(*inputs, frame): def coordinate_to_quantity(*inputs, frame): results = frame.from_high_level_coordinates(*inputs) - if not isinstance(results, list): - results = [results] - return [r << unit for r, unit in zip(results, frame.unit, strict=False)] + if not isinstance(results, tuple): + results = (results,) + return tuple(r << unit for r, unit in zip(results, frame.unit, strict=False)) @pytest.mark.parametrize("inputs", inputs2) From be55069a286fbf24a424514051972644fb5c8ec0 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 18 Feb 2025 19:27:40 -0500 Subject: [PATCH 14/14] Split the rest of the coordinate frames into their own modules --- gwcs/api.py | 17 +- gwcs/coordinate_frames/__init__.py | 20 +- gwcs/coordinate_frames/_axis.py | 2 +- gwcs/coordinate_frames/_base.py | 12 +- gwcs/coordinate_frames/_celestial.py | 134 ++++ gwcs/coordinate_frames/_composite.py | 149 +++++ gwcs/coordinate_frames/_coordinate_frames.py | 618 ------------------- gwcs/coordinate_frames/_core.py | 5 +- gwcs/coordinate_frames/_empty.py | 76 +++ gwcs/coordinate_frames/_frame.py | 57 ++ gwcs/coordinate_frames/_spectral.py | 78 +++ gwcs/coordinate_frames/_stokes.py | 65 ++ gwcs/coordinate_frames/_temporal.py | 117 ++++ gwcs/coordinate_frames/_utils.py | 60 ++ gwcs/tests/test_coordinate_systems.py | 2 +- pyproject.toml | 1 - 16 files changed, 772 insertions(+), 641 deletions(-) create mode 100644 gwcs/coordinate_frames/_celestial.py create mode 100644 gwcs/coordinate_frames/_composite.py delete mode 100644 gwcs/coordinate_frames/_coordinate_frames.py create mode 100644 gwcs/coordinate_frames/_empty.py create mode 100644 gwcs/coordinate_frames/_frame.py create mode 100644 gwcs/coordinate_frames/_spectral.py create mode 100644 gwcs/coordinate_frames/_stokes.py create mode 100644 gwcs/coordinate_frames/_temporal.py create mode 100644 gwcs/coordinate_frames/_utils.py diff --git a/gwcs/api.py b/gwcs/api.py index c00c1c9e..042b28f3 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -5,6 +5,7 @@ """ +from collections.abc import Callable from typing import Any, NamedTuple, TypeAlias import astropy.units as u @@ -19,6 +20,7 @@ "WorldAxisClasses", "WorldAxisComponent", "WorldAxisComponents", + "WorldAxisConverterClass", ] @@ -32,7 +34,18 @@ class WorldAxisClass(NamedTuple): kwargs: dict[str, Any] -WorldAxisClasses: TypeAlias = dict[str | int, WorldAxisClass] +class WorldAxisConverterClass(NamedTuple): + """ + Named tuple for the world_axis_object_classes WCS property, which have a converter + """ + + object_type: type | str + args: tuple[int | None, ...] + kwargs: dict[str, Any] + converter: Callable[..., Any] | None = None + + +WorldAxisClasses: TypeAlias = dict[str | int, WorldAxisClass | WorldAxisConverterClass] class WorldAxisComponent(NamedTuple): @@ -42,7 +55,7 @@ class WorldAxisComponent(NamedTuple): name: str key: str | int - property_name: str + property_name: str | Callable[[Any], Any] WorldAxisComponents: TypeAlias = list[WorldAxisComponent] diff --git a/gwcs/coordinate_frames/__init__.py b/gwcs/coordinate_frames/__init__.py index 38654af7..a53b6412 100644 --- a/gwcs/coordinate_frames/__init__.py +++ b/gwcs/coordinate_frames/__init__.py @@ -116,20 +116,20 @@ """ +from ._axis import AxisType from ._base import BaseCoordinateFrame -from ._coordinate_frames import ( - CelestialFrame, - CompositeFrame, - EmptyFrame, - Frame2D, - SpectralFrame, - StokesFrame, - TemporalFrame, - get_ctype_from_ucd, -) +from ._celestial import CelestialFrame +from ._composite import CompositeFrame from ._core import CoordinateFrame +from ._empty import EmptyFrame +from ._frame import Frame2D +from ._spectral import SpectralFrame +from ._stokes import StokesFrame +from ._temporal import TemporalFrame +from ._utils import get_ctype_from_ucd __all__ = [ + "AxisType", "BaseCoordinateFrame", "CelestialFrame", "CompositeFrame", diff --git a/gwcs/coordinate_frames/_axis.py b/gwcs/coordinate_frames/_axis.py index f05b1c94..fbaaa33b 100644 --- a/gwcs/coordinate_frames/_axis.py +++ b/gwcs/coordinate_frames/_axis.py @@ -21,4 +21,4 @@ def __repr__(self) -> str: return self.value -AxesType: TypeAlias = tuple[AxisType, ...] | AxisType +AxesType: TypeAlias = tuple[AxisType | str, ...] | AxisType | str diff --git a/gwcs/coordinate_frames/_base.py b/gwcs/coordinate_frames/_base.py index 6120a42d..41af62ef 100644 --- a/gwcs/coordinate_frames/_base.py +++ b/gwcs/coordinate_frames/_base.py @@ -6,11 +6,9 @@ import numpy.typing as npt from astropy import units as u from astropy.coordinates import BaseCoordinateFrame as _BaseCoordinateFrame +from astropy.time import Time -from gwcs._typing import ( - AxisPhysicalTypes, - LowLevelUnitValue, -) +from gwcs._typing import AxisPhysicalTypes, LowLevelUnitValue from gwcs.api import WorldAxisClasses, WorldAxisComponent, WorldAxisComponents from ._axis import AxesType @@ -66,7 +64,7 @@ def axes_order(self) -> tuple[int, ...]: @property @abc.abstractmethod - def reference_frame(self) -> _BaseCoordinateFrame | None: + def reference_frame(self) -> _BaseCoordinateFrame | Time | None: """ The reference frame of the coordinates described by this frame. @@ -118,7 +116,9 @@ def world_axis_object_components(self) -> WorldAxisComponents: np.argsort(self.axes_order) ].tolist() - # Makes MyPy happy + # Unpack the arguments passed by the map into the WorldAxisComponent NamedTuple + # NamedTuple apparently does not automatically unpack the arguments like a tuple + # will def _func(arg: Sequence[Any]) -> WorldAxisComponent: return WorldAxisComponent(*arg) diff --git a/gwcs/coordinate_frames/_celestial.py b/gwcs/coordinate_frames/_celestial.py new file mode 100644 index 00000000..f68383bc --- /dev/null +++ b/gwcs/coordinate_frames/_celestial.py @@ -0,0 +1,134 @@ +from astropy import units as u +from astropy.coordinates import ( + GCRS, + BaseCoordinateFrame, + Galactic, + GeocentricTrueEcliptic, + PrecessedGeocentric, + SkyCoord, + builtin_frames, +) + +from gwcs._typing import AxisPhysicalTypes +from gwcs.api import ( + WorldAxisClass, + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, +) + +from ._axis import AxisType +from ._core import CoordinateFrame +from ._properties import FrameProperties + +__all__ = ["CelestialFrame"] + +STANDARD_REFERENCE_FRAMES = [frame.upper() for frame in builtin_frames.__all__] + + +class CelestialFrame(CoordinateFrame): + """ + Representation of a Celesital coordinate system. + + This class has a native order of longitude then latitude, meaning + ``axes_names``, ``unit`` and ``axis_physical_types`` should be lon, lat + ordered. If your transform is in a different order this should be specified + with ``axes_order``. + + Parameters + ---------- + axes_order + A dimension in the input data that corresponds to this axis. + reference_frame + A reference frame. + unit + Units on axes. + axes_names + Names of the axes in this frame. + name + Name of this frame. + axis_physical_types + The UCD 1+ physical types for the axes, in frame order (lon, lat). + """ + + def __init__( + self, + axes_order: tuple[int, ...] | None = None, + reference_frame: BaseCoordinateFrame | None = None, + unit: tuple[u.Unit, ...] | None = None, + axes_names: tuple[str, ...] | None = None, + name: str | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + naxes = 2 + if ( + reference_frame is not None + and not isinstance(reference_frame, str) + and reference_frame.name.upper() in STANDARD_REFERENCE_FRAMES + ): + _axes_names = ( + tuple( + n + for n in reference_frame.representation_component_names.values() + if n != "distance" + ) + if axes_names is None + else axes_names + ) + naxes = len(_axes_names) + + self.native_axes_order = tuple(range(naxes)) + if axes_order is None: + axes_order = self.native_axes_order + if unit is None: + # Astropy dynamically creates some units, so MyPy can't find them + unit = tuple([u.degree] * naxes) # type: ignore[attr-defined] + axes_type = (AxisType.SPATIAL,) * naxes + + super().__init__( + naxes=naxes, + axes_type=axes_type, + axes_order=axes_order, + reference_frame=reference_frame, + unit=unit, + axes_names=_axes_names, + name=name, + axis_physical_types=axis_physical_types, + ) + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + if isinstance(self.reference_frame, Galactic): + return "pos.galactic.lon", "pos.galactic.lat" + if isinstance( + self.reference_frame, + GeocentricTrueEcliptic | GCRS | PrecessedGeocentric, + ): + return "pos.bodyrc.lon", "pos.bodyrc.lat" + if isinstance(self.reference_frame, builtin_frames.BaseRADecFrame): + return "pos.eq.ra", "pos.eq.dec" + if isinstance(self.reference_frame, builtin_frames.BaseEclipticFrame): + return "pos.ecliptic.lon", "pos.ecliptic.lat" + return tuple(f"custom:{t}" for t in properties.axes_names) + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return { + "celestial": WorldAxisClass( + SkyCoord, + (), + {"frame": self.reference_frame, "unit": self._prop.unit}, + ) + } + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: + return [ + WorldAxisComponent( + "celestial", 0, lambda sc: sc.spherical.lon.to_value(self._prop.unit[0]) + ), + WorldAxisComponent( + "celestial", 1, lambda sc: sc.spherical.lat.to_value(self._prop.unit[1]) + ), + ] diff --git a/gwcs/coordinate_frames/_composite.py b/gwcs/coordinate_frames/_composite.py new file mode 100644 index 00000000..e9f35d46 --- /dev/null +++ b/gwcs/coordinate_frames/_composite.py @@ -0,0 +1,149 @@ +from collections import defaultdict +from collections.abc import Generator +from typing import Any, cast + +import numpy as np +from astropy import units as u + +from gwcs._typing import AxisPhysicalType +from gwcs.api import ( + WorldAxisClass, + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, + WorldAxisConverterClass, +) + +from ._axis import AxisType +from ._base import BaseCoordinateFrame +from ._core import CoordinateFrame + +__all__ = ["CompositeFrame"] + + +class CompositeFrame(CoordinateFrame): + """ + Represents one or more frames. + + Parameters + ---------- + frames + List of constituient frames. + name + Name for this frame. + """ + + def __init__( + self, frames: list[BaseCoordinateFrame], name: str | None = None + ) -> None: + self._frames = frames[:] + naxes = sum([frame.naxes for frame in self._frames]) + + axes_order: list[int] = [] + axes_type: list[AxisType | str] = [] + axes_names: list[str] = [] + unit: list[u.Unit] = [] + ph_type: list[AxisPhysicalType] = [] + + for frame in frames: + axes_order.extend(frame.axes_order) + + # Stack the raw (not-native) ordered properties + for frame in frames: + axes_type += list(frame._prop.axes_type) + axes_names += list(frame._prop.axes_names) + # no common base class in astropy.units for all units + unit += list(frame._prop.unit) # type: ignore[arg-type] + ph_type += list(frame._prop.axis_physical_types) + + if len(np.unique(axes_order)) != len(axes_order): + msg = ( + "Incorrect numbering of axes, " + "axes_order should contain unique numbers, " + f"got {axes_order}." + ) + raise ValueError(msg) + + super().__init__( + naxes, + axes_type=tuple(axes_type), + axes_order=tuple(axes_order), + unit=tuple(unit), + axes_names=tuple(axes_names), + axis_physical_types=tuple(ph_type), + name=name, + ) + self._axis_physical_types = tuple(ph_type) + + @property + def frames(self) -> list[BaseCoordinateFrame]: + """ + The constituient frames that comprise this `CompositeFrame`. + """ + return self._frames + + def __repr__(self) -> str: + return repr(self.frames) + + @property + def _wao_classes_rename_map(self) -> dict[Any, Any]: + mapper: dict[Any, Any] = defaultdict(dict) + seen_names: list[str] = [] + for frame in self.frames: + # ensure the frame is in the mapper + mapper[frame] + for key in frame.world_axis_object_classes: + key = cast(str, key) + if key in seen_names: + new_key = f"{key}{seen_names.count(key)}" + mapper[frame][key] = new_key + seen_names.append(key) + return mapper + + @property + def _wao_renamed_components_iter( + self, + ) -> Generator[tuple[BaseCoordinateFrame, list[WorldAxisComponent]], None, None]: + mapper = self._wao_classes_rename_map + for frame in self.frames: + renamed_components: list[WorldAxisComponent] = [] + for component in frame._native_world_axis_object_components: + rename: str = mapper[frame].get(component[0], component[0]) + renamed_components.append( + WorldAxisComponent(rename, component[1], component[2]) + ) + + yield frame, renamed_components + + @property + def _wao_renamed_classes_iter( + self, + ) -> Generator[tuple[str, WorldAxisClass | WorldAxisConverterClass], None, None]: + mapper = self._wao_classes_rename_map + for frame in self.frames: + for key, value in frame.world_axis_object_classes.items(): + rename: str = mapper[frame].get(key, key) + yield rename, value + + @property + def world_axis_object_components(self) -> WorldAxisComponents: + """ + Object components for this frame. + """ + out: list[WorldAxisComponent | None] = [None] * self.naxes + + for frame, components in self._wao_renamed_components_iter: + for i, ao in enumerate(frame.axes_order): + out[ao] = components[i] + + if any(o is None for o in out): + msg = "axes_order leads to incomplete world_axis_object_components" + raise ValueError(msg) + + # There can be None in the list here, but this is unique to this and + # annoying otherwise, so we ignore MyPy + return out # type: ignore[return-value] + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return dict(self._wao_renamed_classes_iter) diff --git a/gwcs/coordinate_frames/_coordinate_frames.py b/gwcs/coordinate_frames/_coordinate_frames.py deleted file mode 100644 index f98ed26d..00000000 --- a/gwcs/coordinate_frames/_coordinate_frames.py +++ /dev/null @@ -1,618 +0,0 @@ -# Licensed under a 3-clause BSD style license - see LICENSE.rst - -import contextlib -import logging -from collections import defaultdict - -import numpy as np -from astropy import coordinates as coord -from astropy import time -from astropy import units as u -from astropy.coordinates import StokesCoord -from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 - -from ._core import CoordinateFrame - -__all__ = [ - "CelestialFrame", - "CompositeFrame", - "EmptyFrame", - "Frame2D", - "SpectralFrame", - "StokesFrame", - "TemporalFrame", - "get_ctype_from_ucd", -] - - -def _ucd1_to_ctype_name_mapping(ctype_to_ucd, allowed_ucd_duplicates): - inv_map = {} - new_ucd = set() - - for kwd, ucd in ctype_to_ucd.items(): - if ucd in inv_map: - if ucd not in allowed_ucd_duplicates: - new_ucd.add(ucd) - continue - inv_map[ucd] = allowed_ucd_duplicates.get(ucd, kwd) - - if new_ucd: - logging.warning( - "Found unsupported duplicate physical type in 'astropy' mapping to CTYPE.\n" - "Update 'gwcs' to the latest version or notify 'gwcs' developer.\n" - "Duplicate physical types will be mapped to the following CTYPEs:\n" - + "\n".join([f"{ucd!r:s} --> {inv_map[ucd]!r:s}" for ucd in new_ucd]) - ) - - return inv_map - - -# List below allowed physical type duplicates and a corresponding CTYPE -# to which all duplicates will be mapped to: -_ALLOWED_UCD_DUPLICATES = { - "time": "TIME", - "em.wl": "WAVE", -} - -UCD1_TO_CTYPE = _ucd1_to_ctype_name_mapping( - ctype_to_ucd=CTYPE_TO_UCD1, allowed_ucd_duplicates=_ALLOWED_UCD_DUPLICATES -) - -STANDARD_REFERENCE_FRAMES = [frame.upper() for frame in coord.builtin_frames.__all__] - - -def get_ctype_from_ucd(ucd): - """ - Return the FITS ``CTYPE`` corresponding to a UCD1 value. - - Parameters - ---------- - ucd : str - UCD string, for example one of ```WCS.world_axis_physical_types``. - - Returns - ------- - CTYPE : str - The corresponding FITS ``CTYPE`` value or an empty string. - """ - return UCD1_TO_CTYPE.get(ucd, "") - - -class EmptyFrame(CoordinateFrame): - """ - Represents a "default" detector frame. This is for use as the default value - for input frame by the WCS object. - """ - - def __init__(self, name=None): - self._name = "detector" if name is None else name - - def __repr__(self): - return f'<{type(self).__name__}(name="{self.name}")>' - - def __str__(self): - if self._name is not None: - return self._name - return type(self).__name__ - - @property - def name(self): - """A custom name of this frame.""" - return self._name - - @name.setter - def name(self, val): - """A custom name of this frame.""" - self._name = val - - def _raise_error(self) -> None: - msg = "EmptyFrame does not have any information" - raise NotImplementedError(msg) - - @property - def naxes(self): - self._raise_error() - - @property - def unit(self): - self._raise_error() - - @property - def axes_names(self): - self._raise_error() - - @property - def axes_order(self): - self._raise_error() - - @property - def reference_frame(self): - self._raise_error() - - @property - def axes_type(self): - self._raise_error() - - @property - def axis_physical_types(self): - self._raise_error() - - @property - def world_axis_object_classes(self): - self._raise_error() - - @property - def _native_world_axis_object_components(self): - self._raise_error() - - def to_high_level_coordinates(self, *values): - self._raise_error() - - def from_high_level_coordinates(self, *high_level_coords): - self._raise_error() - - -class CelestialFrame(CoordinateFrame): - """ - Representation of a Celesital coordinate system. - - This class has a native order of longitude then latitude, meaning - ``axes_names``, ``unit`` and ``axis_physical_types`` should be lon, lat - ordered. If your transform is in a different order this should be specified - with ``axes_order``. - - Parameters - ---------- - axes_order : tuple of int - A dimension in the input data that corresponds to this axis. - reference_frame : astropy.coordinates.builtin_frames - A reference frame. - unit : str or units.Unit instance or iterable of those - Units on axes. - axes_names : list - Names of the axes in this frame. - name : str - Name of this frame. - axis_physical_types : list - The UCD 1+ physical types for the axes, in frame order (lon, lat). - """ - - def __init__( - self, - axes_order=None, - reference_frame=None, - unit=None, - axes_names=None, - name=None, - axis_physical_types=None, - ): - naxes = 2 - if ( - reference_frame is not None - and not isinstance(reference_frame, str) - and reference_frame.name.upper() in STANDARD_REFERENCE_FRAMES - ): - _axes_names = ( - tuple( - n - for n in reference_frame.representation_component_names.values() - if n != "distance" - ) - if axes_names is None - else axes_names - ) - naxes = len(_axes_names) - - self.native_axes_order = tuple(range(naxes)) - if axes_order is None: - axes_order = self.native_axes_order - if unit is None: - unit = tuple([u.degree] * naxes) - axes_type = ["SPATIAL"] * naxes - - super().__init__( - naxes=naxes, - axes_type=axes_type, - axes_order=axes_order, - reference_frame=reference_frame, - unit=unit, - axes_names=_axes_names, - name=name, - axis_physical_types=axis_physical_types, - ) - - def _default_axis_physical_types(self, properties): - if isinstance(self.reference_frame, coord.Galactic): - return "pos.galactic.lon", "pos.galactic.lat" - if isinstance( - self.reference_frame, - coord.GeocentricTrueEcliptic | coord.GCRS | coord.PrecessedGeocentric, - ): - return "pos.bodyrc.lon", "pos.bodyrc.lat" - if isinstance(self.reference_frame, coord.builtin_frames.BaseRADecFrame): - return "pos.eq.ra", "pos.eq.dec" - if isinstance(self.reference_frame, coord.builtin_frames.BaseEclipticFrame): - return "pos.ecliptic.lon", "pos.ecliptic.lat" - return tuple(f"custom:{t}" for t in properties.axes_names) - - @property - def world_axis_object_classes(self): - return { - "celestial": ( - coord.SkyCoord, - (), - {"frame": self.reference_frame, "unit": self._prop.unit}, - ) - } - - @property - def _native_world_axis_object_components(self): - return [ - ("celestial", 0, lambda sc: sc.spherical.lon.to_value(self._prop.unit[0])), - ("celestial", 1, lambda sc: sc.spherical.lat.to_value(self._prop.unit[1])), - ] - - -class SpectralFrame(CoordinateFrame): - """ - Represents Spectral Frame - - Parameters - ---------- - axes_order : tuple or int - A dimension in the input data that corresponds to this axis. - reference_frame : astropy.coordinates.builtin_frames - Reference frame (usually used with output_frame to convert to world - coordinate objects). - unit : str or units.Unit instance - Spectral unit. - axes_names : str - Spectral axis name. - name : str - Name for this frame. - - """ - - def __init__( - self, - axes_order=(0,), - reference_frame=None, - unit=None, - axes_names=None, - name=None, - axis_physical_types=None, - ): - super().__init__( - naxes=1, - axes_type="SPECTRAL", - axes_order=axes_order, - axes_names=axes_names, - reference_frame=reference_frame, - unit=unit, - name=name, - axis_physical_types=axis_physical_types, - ) - - def _default_axis_physical_types(self, properties): - if properties.unit[0].physical_type == "frequency": - return ("em.freq",) - if properties.unit[0].physical_type == "length": - return ("em.wl",) - if properties.unit[0].physical_type == "energy": - return ("em.energy",) - if properties.unit[0].physical_type == "speed": - return ("spect.dopplerVeloc",) - logging.warning( - "Physical type may be ambiguous. Consider " - "setting the physical type explicitly as " - "either 'spect.dopplerVeloc.optical' or " - "'spect.dopplerVeloc.radio'." - ) - return (f"custom:{properties.unit[0].physical_type}",) - - @property - def world_axis_object_classes(self): - return {"spectral": (coord.SpectralCoord, (), {"unit": self.unit[0]})} - - @property - def _native_world_axis_object_components(self): - return [("spectral", 0, lambda sc: sc.to_value(self.unit[0]))] - - -class TemporalFrame(CoordinateFrame): - """ - A coordinate frame for time axes. - - Parameters - ---------- - reference_frame : `~astropy.time.Time` - A Time object which holds the time scale and format. - If data is provided, it is the time zero point. - To not set a zero point for the frame initialize ``reference_frame`` - with an empty list. - unit : str or `~astropy.units.Unit` - Time unit. - axes_names : str - Time axis name. - axes_order : tuple or int - A dimension in the data that corresponds to this axis. - name : str - Name for this frame. - """ - - def __init__( - self, - reference_frame, - unit=u.s, - axes_order=(0,), - axes_names=None, - name=None, - axis_physical_types=None, - ): - _axes_names = ( - ( - f"{reference_frame.format}({reference_frame.scale}; " - f"{reference_frame.location}", - ) - if axes_names is None - else axes_names - ) - - super().__init__( - naxes=1, - axes_type="TIME", - axes_order=axes_order, - axes_names=_axes_names, - reference_frame=reference_frame, - unit=unit, - name=name, - axis_physical_types=axis_physical_types, - ) - self._attrs = {} - for a in self.reference_frame.info._represent_as_dict_extra_attrs: - with contextlib.suppress(AttributeError): - self._attrs[a] = getattr(self.reference_frame, a) - - def _default_axis_physical_types(self, properties): - return ("time",) - - def _convert_to_time(self, dt, *, unit, **kwargs): - if ( - not isinstance(dt, time.TimeDelta) and isinstance(dt, time.Time) - ) or isinstance(self.reference_frame.value, np.ndarray): - return time.Time(dt, **kwargs) - - if not hasattr(dt, "unit"): - dt = dt * unit - - return self.reference_frame + dt - - @property - def world_axis_object_classes(self): - comp = ( - time.Time, - (), - {"unit": self.unit[0], **self._attrs}, - self._convert_to_time, - ) - - return {"temporal": comp} - - @property - def _native_world_axis_object_components(self): - if isinstance(self.reference_frame.value, np.ndarray): - return [("temporal", 0, "value")] - - def offset_from_time_and_reference(time): - return (time - self.reference_frame).sec - - return [("temporal", 0, offset_from_time_and_reference)] - - -class CompositeFrame(CoordinateFrame): - """ - Represents one or more frames. - - Parameters - ---------- - frames : list - List of constituient frames. - name : str - Name for this frame. - """ - - def __init__(self, frames, name=None): - self._frames = frames[:] - naxes = sum([frame._naxes for frame in self._frames]) - - axes_order = [] - axes_type = [] - axes_names = [] - unit = [] - ph_type = [] - - for frame in frames: - axes_order.extend(frame.axes_order) - - # Stack the raw (not-native) ordered properties - for frame in frames: - axes_type += list(frame._prop.axes_type) - axes_names += list(frame._prop.axes_names) - unit += list(frame._prop.unit) - ph_type += list(frame._prop.axis_physical_types) - - if len(np.unique(axes_order)) != len(axes_order): - msg = ( - "Incorrect numbering of axes, " - "axes_order should contain unique numbers, " - f"got {axes_order}." - ) - raise ValueError(msg) - - super().__init__( - naxes, - axes_type=tuple(axes_type), - axes_order=tuple(axes_order), - unit=tuple(unit), - axes_names=tuple(axes_names), - axis_physical_types=tuple(ph_type), - name=name, - ) - self._axis_physical_types = tuple(ph_type) - - @property - def frames(self): - """ - The constituient frames that comprise this `CompositeFrame`. - """ - return self._frames - - def __repr__(self): - return repr(self.frames) - - @property - def _wao_classes_rename_map(self): - mapper = defaultdict(dict) - seen_names = [] - for frame in self.frames: - # ensure the frame is in the mapper - mapper[frame] - for key in frame.world_axis_object_classes: - if key in seen_names: - new_key = f"{key}{seen_names.count(key)}" - mapper[frame][key] = new_key - seen_names.append(key) - return mapper - - @property - def _wao_renamed_components_iter(self): - mapper = self._wao_classes_rename_map - for frame in self.frames: - renamed_components = [] - for component in frame._native_world_axis_object_components: - comp = list(component) - rename = mapper[frame].get(comp[0]) - if rename: - comp[0] = rename - renamed_components.append(tuple(comp)) - yield frame, renamed_components - - @property - def _wao_renamed_classes_iter(self): - mapper = self._wao_classes_rename_map - for frame in self.frames: - for key, value in frame.world_axis_object_classes.items(): - rename = mapper[frame].get(key) - yield rename if rename else key, value - - @property - def world_axis_object_components(self): - out = [None] * self.naxes - - for frame, components in self._wao_renamed_components_iter: - for i, ao in enumerate(frame.axes_order): - out[ao] = components[i] - - if any(o is None for o in out): - msg = "axes_order leads to incomplete world_axis_object_components" - raise ValueError(msg) - - return out - - @property - def world_axis_object_classes(self): - return dict(self._wao_renamed_classes_iter) - - -class StokesFrame(CoordinateFrame): - """ - A coordinate frame for representing Stokes polarisation states. - - Parameters - ---------- - name : str - Name of this frame. - axes_order : tuple - A dimension in the data that corresponds to this axis. - """ - - def __init__( - self, - axes_order=(0,), - axes_names=("stokes",), - name=None, - axis_physical_types=None, - ): - super().__init__( - 1, - ["STOKES"], - axes_order, - name=name, - axes_names=axes_names, - unit=u.one, - axis_physical_types=axis_physical_types, - ) - - def _default_axis_physical_types(self, properties): - return ("phys.polarization.stokes",) - - @property - def world_axis_object_classes(self): - return { - "stokes": ( - StokesCoord, - (), - {}, - ) - } - - @property - def _native_world_axis_object_components(self): - return [("stokes", 0, "value")] - - -class Frame2D(CoordinateFrame): - """ - A 2D coordinate frame. - - Parameters - ---------- - axes_order : tuple of int - A dimension in the input data that corresponds to this axis. - unit : list of astropy.units.Unit - Unit for each axis. - axes_names : list - Names of the axes in this frame. - name : str - Name of this frame. - """ - - def __init__( - self, - axes_order=(0, 1), - unit=(u.pix, u.pix), - axes_names=("x", "y"), - name=None, - axes_type=None, - axis_physical_types=None, - ): - if axes_type is None: - axes_type = ["SPATIAL", "SPATIAL"] - - super().__init__( - naxes=2, - axes_type=axes_type, - axes_order=axes_order, - name=name, - axes_names=axes_names, - unit=unit, - axis_physical_types=axis_physical_types, - ) - - def _default_axis_physical_types(self, properties): - if properties.axes_names is not None and all(properties.axes_names): - ph_type = properties.axes_names - else: - ph_type = properties.axes_type - - return tuple(f"custom:{t}" for t in ph_type) diff --git a/gwcs/coordinate_frames/_core.py b/gwcs/coordinate_frames/_core.py index 85b1984c..8c4fb799 100644 --- a/gwcs/coordinate_frames/_core.py +++ b/gwcs/coordinate_frames/_core.py @@ -4,6 +4,7 @@ import numpy as np from astropy import units as u from astropy.coordinates import BaseCoordinateFrame as _BaseCoordinateFrame +from astropy.time import Time from astropy.wcs.wcsapi.high_level_api import ( high_level_objects_to_values, values_to_high_level_objects, @@ -60,7 +61,7 @@ def __init__( naxes: int, axes_type: AxesType, axes_order: tuple[int, ...], - reference_frame: _BaseCoordinateFrame | None = None, + reference_frame: _BaseCoordinateFrame | Time | None = None, unit: tuple[u.Unit, ...] | None = None, axes_names: tuple[str, ...] | None = None, name: str | None = None, @@ -152,7 +153,7 @@ def axes_order(self) -> tuple[int, ...]: return self._axes_order @property - def reference_frame(self) -> _BaseCoordinateFrame | None: + def reference_frame(self) -> _BaseCoordinateFrame | Time | None: """Reference frame, used to convert to world coordinate objects.""" return self._reference_frame diff --git a/gwcs/coordinate_frames/_empty.py b/gwcs/coordinate_frames/_empty.py new file mode 100644 index 00000000..9e5e2101 --- /dev/null +++ b/gwcs/coordinate_frames/_empty.py @@ -0,0 +1,76 @@ +from astropy import units as u +from astropy.coordinates import BaseCoordinateFrame as _BaseCoordinateFrame + +from gwcs._typing import AxisPhysicalTypes +from gwcs.api import WorldAxisClasses, WorldAxisComponents + +from ._axis import AxesType +from ._core import CoordinateFrame + +__all__ = ["EmptyFrame"] + + +class EmptyFrame(CoordinateFrame): + """ + Represents a "default" detector frame. This is for use as the default value + for input frame by the WCS object. + """ + + def __init__(self, name: str | None = None) -> None: + self._name = "detector" if name is None else name + + def __repr__(self) -> str: + return f'<{type(self).__name__}(name="{self.name}")>' + + def __str__(self) -> str: + return self._name + + @property + def name(self) -> str: + """A custom name of this frame.""" + return self._name + + @name.setter + def name(self, val: str) -> None: + """A custom name of this frame.""" + self._name = val + + def _raise_error(self) -> None: + msg = "EmptyFrame does not have any information" + raise NotImplementedError(msg) + + @property + def naxes(self) -> int: # type: ignore[return] + self._raise_error() + + @property + def unit(self) -> tuple[u.Unit, ...]: # type: ignore[return] + self._raise_error() + + @property + def axes_names(self) -> tuple[str, ...]: # type: ignore[return] + self._raise_error() + + @property + def axes_order(self) -> tuple[int, ...]: # type: ignore[return] + self._raise_error() + + @property + def reference_frame(self) -> _BaseCoordinateFrame | None: # type: ignore[return] + self._raise_error() + + @property + def axes_type(self) -> AxesType: # type: ignore[return] + self._raise_error() + + @property + def axis_physical_types(self) -> AxisPhysicalTypes: # type: ignore[return] + self._raise_error() + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: # type: ignore[return] + self._raise_error() + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: # type: ignore[return] + self._raise_error() diff --git a/gwcs/coordinate_frames/_frame.py b/gwcs/coordinate_frames/_frame.py new file mode 100644 index 00000000..98fdbdc0 --- /dev/null +++ b/gwcs/coordinate_frames/_frame.py @@ -0,0 +1,57 @@ +from astropy import units as u + +from gwcs._typing import AxisPhysicalTypes + +from ._axis import AxesType, AxisType +from ._core import CoordinateFrame +from ._properties import FrameProperties + + +class Frame2D(CoordinateFrame): + """ + A 2D coordinate frame. + + Parameters + ---------- + axes_order + A dimension in the input data that corresponds to this axis. + unit + Unit for each axis. + axes_name + Names of the axes in this frame. + name + Name of this frame. + """ + + def __init__( + self, + axes_order: tuple[int, ...] = (0, 1), + # Astropy dynamically builds these types at runtime, so MyPy can't find them + unit: tuple[u.Unit, ...] = (u.pix, u.pix), # type: ignore[attr-defined] + axes_names: tuple[str, ...] = ("x", "y"), + name: str | None = None, + axes_type: AxesType | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + if axes_type is None: + axes_type = (AxisType.SPATIAL, AxisType.SPATIAL) + + super().__init__( + naxes=2, + axes_type=axes_type, + axes_order=axes_order, + name=name, + axes_names=axes_names, + unit=unit, + axis_physical_types=axis_physical_types, + ) + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + if all(properties.axes_names): + ph_type = properties.axes_names + else: + ph_type = properties.axes_type + + return tuple(f"custom:{t}" for t in ph_type) diff --git a/gwcs/coordinate_frames/_spectral.py b/gwcs/coordinate_frames/_spectral.py new file mode 100644 index 00000000..8dcce4c7 --- /dev/null +++ b/gwcs/coordinate_frames/_spectral.py @@ -0,0 +1,78 @@ +from astropy import units as u +from astropy.coordinates import BaseCoordinateFrame, SpectralCoord + +from gwcs._typing import AxisPhysicalTypes +from gwcs.api import ( + WorldAxisClass, + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, +) + +from ._axis import AxisType +from ._core import CoordinateFrame +from ._properties import FrameProperties + +__all__ = ["SpectralFrame"] + + +class SpectralFrame(CoordinateFrame): + """ + Represents Spectral Frame + + Parameters + ---------- + axes_order + A dimension in the input data that corresponds to this axis. + reference_frame : astropy.coordinates.builtin_frames + Reference frame (usually used with output_frame to convert to world + coordinate objects). + unit + Spectral unit. + axes_names + Spectral axis name. + name + Name for this frame. + + """ + + def __init__( + self, + axes_order: tuple[int, ...] = (0,), + reference_frame: BaseCoordinateFrame | None = None, + unit: tuple[u.Unit, ...] | None = None, + axes_names: tuple[str, ...] | None = None, + name: str | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + super().__init__( + naxes=1, + axes_type=AxisType.SPECTRAL, + axes_order=axes_order, + axes_names=axes_names, + reference_frame=reference_frame, + unit=unit, + name=name, + axis_physical_types=axis_physical_types, + ) + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + if properties.unit[0].physical_type == "frequency": + return ("em.freq",) + if properties.unit[0].physical_type == "length": + return ("em.wl",) + if properties.unit[0].physical_type == "energy": + return ("em.energy",) + if properties.unit[0].physical_type == "speed": + return ("spect.dopplerVeloc",) + return (f"custom:{properties.unit[0].physical_type}",) + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return {"spectral": WorldAxisClass(SpectralCoord, (), {"unit": self.unit[0]})} + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: + return [WorldAxisComponent("spectral", 0, lambda sc: sc.to_value(self.unit[0]))] diff --git a/gwcs/coordinate_frames/_stokes.py b/gwcs/coordinate_frames/_stokes.py new file mode 100644 index 00000000..20aa9f78 --- /dev/null +++ b/gwcs/coordinate_frames/_stokes.py @@ -0,0 +1,65 @@ +from astropy import units as u +from astropy.coordinates import StokesCoord + +from gwcs._typing import AxisPhysicalTypes +from gwcs.api import ( + WorldAxisClass, + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, +) + +from ._axis import AxisType +from ._core import CoordinateFrame +from ._properties import FrameProperties + +__all__ = ["StokesFrame"] + + +class StokesFrame(CoordinateFrame): + """ + A coordinate frame for representing Stokes polarisation states. + + Parameters + ---------- + name + Name of this frame. + axes_order + A dimension in the data that corresponds to this axis. + """ + + def __init__( + self, + axes_order: tuple[int, ...] = (0,), + axes_names: tuple[str, ...] = ("stokes",), + name: str | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + super().__init__( + 1, + (AxisType.STOKES,), + axes_order, + name=name, + axes_names=axes_names, + unit=(u.one,), # type: ignore[arg-type] + axis_physical_types=axis_physical_types, + ) + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + return ("phys.polarization.stokes",) + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return { + "stokes": WorldAxisClass( + StokesCoord, + (), + {}, + ) + } + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: + return [WorldAxisComponent("stokes", 0, "value")] diff --git a/gwcs/coordinate_frames/_temporal.py b/gwcs/coordinate_frames/_temporal.py new file mode 100644 index 00000000..3b527151 --- /dev/null +++ b/gwcs/coordinate_frames/_temporal.py @@ -0,0 +1,117 @@ +import contextlib +from typing import Any, cast + +import numpy as np +from astropy import units as u +from astropy.time import Time, TimeDelta + +from gwcs._typing import AxisPhysicalTypes +from gwcs.api import ( + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, + WorldAxisConverterClass, +) + +from ._axis import AxisType +from ._core import CoordinateFrame +from ._properties import FrameProperties + +__all__ = ["TemporalFrame"] + + +class TemporalFrame(CoordinateFrame): + """ + A coordinate frame for time axes. + + Parameters + ---------- + reference_frame + A Time object which holds the time scale and format. + If data is provided, it is the time zero point. + To not set a zero point for the frame initialize ``reference_frame`` + with an empty list. + unit + Time unit. + axes_names + Time axis name. + axes_order + A dimension in the data that corresponds to this axis. + name + Name for this frame. + """ + + def __init__( + self, + reference_frame: Time, + # Astropy dynamically builds these types at runtime, so MyPy can't find them + unit: tuple[u.Unit, ...] = (u.s,), # type: ignore[attr-defined] + axes_order: tuple[int, ...] = (0,), + axes_names: tuple[str, ...] | None = None, + name: str | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + _axes_names = ( + ( + f"{reference_frame.format}({reference_frame.scale}; " + f"{reference_frame.location}", + ) + if axes_names is None + else axes_names + ) + + super().__init__( + naxes=1, + axes_type=AxisType.TIME, + axes_order=axes_order, + axes_names=_axes_names, + reference_frame=reference_frame, + unit=unit, + name=name, + axis_physical_types=axis_physical_types, + ) + self._attrs = {} + for a in self.reference_frame.info._represent_as_dict_extra_attrs: + with contextlib.suppress(AttributeError): + self._attrs[a] = getattr(self.reference_frame, a) + + @property + def reference_frame(self) -> Time: + return cast(Time, self._reference_frame) + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + return ("time",) + + def _convert_to_time(self, dt: Any, *, unit: u.Unit, **kwargs: Any) -> Time: + if (not isinstance(dt, TimeDelta) and isinstance(dt, Time)) or isinstance( + self.reference_frame.value, np.ndarray + ): + return Time(dt, **kwargs) # type: ignore[no-untyped-call] + + if not hasattr(dt, "unit"): + dt = dt * unit + + return cast(Time, self.reference_frame + dt) + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return { + "temporal": WorldAxisConverterClass( + Time, + (), + {"unit": self.unit[0], **self._attrs}, + self._convert_to_time, + ) + } + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: + if isinstance(self.reference_frame.value, np.ndarray): + return [WorldAxisComponent("temporal", 0, "value")] + + def offset_from_time_and_reference(time: Time) -> Any: + return (time - self.reference_frame).sec + + return [WorldAxisComponent("temporal", 0, offset_from_time_and_reference)] diff --git a/gwcs/coordinate_frames/_utils.py b/gwcs/coordinate_frames/_utils.py new file mode 100644 index 00000000..2b9c71c1 --- /dev/null +++ b/gwcs/coordinate_frames/_utils.py @@ -0,0 +1,60 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst + +import logging + +from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 + +__all__ = ["get_ctype_from_ucd"] + + +def _ucd1_to_ctype_name_mapping( + ctype_to_ucd: dict[str, str], allowed_ucd_duplicates: dict[str, str] +) -> dict[str, str]: + inv_map = {} + new_ucd = set() + + for kwd, ucd in ctype_to_ucd.items(): + if ucd in inv_map: + if ucd not in allowed_ucd_duplicates: + new_ucd.add(ucd) + continue + inv_map[ucd] = allowed_ucd_duplicates.get(ucd, kwd) + + if new_ucd: + logging.warning( + "Found unsupported duplicate physical type in 'astropy' mapping to CTYPE.\n" + "Update 'gwcs' to the latest version or notify 'gwcs' developer.\n" + "Duplicate physical types will be mapped to the following CTYPEs:\n" + + "\n".join([f"{ucd!r:s} --> {inv_map[ucd]!r:s}" for ucd in new_ucd]) + ) + + return inv_map + + +# List below allowed physical type duplicates and a corresponding CTYPE +# to which all duplicates will be mapped to: +_ALLOWED_UCD_DUPLICATES = { + "time": "TIME", + "em.wl": "WAVE", +} + +UCD1_TO_CTYPE = _ucd1_to_ctype_name_mapping( + ctype_to_ucd=CTYPE_TO_UCD1, allowed_ucd_duplicates=_ALLOWED_UCD_DUPLICATES +) + + +def get_ctype_from_ucd(ucd: str) -> str: + """ + Return the FITS ``CTYPE`` corresponding to a UCD1 value. + + Parameters + ---------- + ucd : str + UCD string, for example one of ```WCS.world_axis_physical_types``. + + Returns + ------- + CTYPE : str + The corresponding FITS ``CTYPE`` value or an empty string. + """ + return UCD1_TO_CTYPE.get(ucd, "") diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index cfcad311..96948638 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -16,7 +16,7 @@ from gwcs import WCS from gwcs import coordinate_frames as cf -from gwcs.coordinate_frames._coordinate_frames import ( +from gwcs.coordinate_frames._utils import ( _ALLOWED_UCD_DUPLICATES, _ucd1_to_ctype_name_mapping, ) diff --git a/pyproject.toml b/pyproject.toml index 93cc57f3..8c1008d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -221,7 +221,6 @@ module = [ "gwcs.converters.*", "gwcs.wcs.*", "gwcs.api.*", - "gwcs.coordinate_frames._coordinate_frames.*", "gwcs.examples.*", "gwcs.extension.*", "gwcs.geometry.*",