diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a0f3981f..e3598fda 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -59,3 +59,13 @@ repos: - id: ruff args: ["--fix", "--show-fixes"] - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.15.0" + hooks: + - id: mypy + files: gwcs + args: [] + additional_dependencies: # add dependencies for mypy to pull information from + - numpy>=2 + - astropy>=7 diff --git a/CHANGES.rst b/CHANGES.rst index 477d87f9..e31aa0ae 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,12 @@ - Fix API issue with ``wcs.numerical_inverse``. [#565] +- Bugfix for ``__call__`` and ``invert`` incorrectly handling units when involving + "parameterless" transforms. [#562] + +- Fix bug where "vector" (shape (n,) not shape (1, n)) arrays would loose all their entries except the + first if ``with_units=True`` was used. [#563] + 0.24.0 (2025-02-04) ------------------- diff --git a/gwcs/_typing.py b/gwcs/_typing.py new file mode 100644 index 00000000..6aed2a9c --- /dev/null +++ b/gwcs/_typing.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from fractions import Fraction +from typing import TypeAlias + +import numpy as np +import numpy.typing as npt +from astropy.coordinates import ( + BaseCoordinateFrame, + SkyCoord, + SpectralCoord, + StokesCoord, +) +from astropy.modeling.bounding_box import CompoundBoundingBox, ModelBoundingBox +from astropy.time import Time +from astropy.units import Quantity + +__all__ = [ + "AxisPhysicalType", + "AxisPhysicalTypes", + "BoundingBox", + "Bounds", + "HighLevelObject", + "HighLevelObjects", + "Interval", + "LowLevelArrays", + "LowLevelUnitArrays", + "LowLevelUnitValue", + "LowLevelValue", + "OutputLowLevelArray", + "Real", +] + +Real: TypeAlias = int | float | Fraction | np.integer | np.floating + +Interval: TypeAlias = tuple[Real, Real] +Bounds: TypeAlias = tuple[Interval, ...] | None + +BoundingBox: TypeAlias = ModelBoundingBox | CompoundBoundingBox | None + +# This is to represent a single value from a low-level function. +LowLevelValue: TypeAlias = Real | npt.NDArray[np.number] +# Handle when units are a possibility. Not all functions allow units in/out +LowLevelUnitValue: TypeAlias = LowLevelValue | Quantity + +# This is to represent all the values together for a single low-level function. +LowLevelArrays: TypeAlias = tuple[LowLevelValue, ...] | LowLevelValue +LowLevelUnitArrays: TypeAlias = tuple[LowLevelUnitValue, ...] + +# This is to represent a general array output from a low-level function. +# Due to the fact 1D outputs are returned as a single value, rather than a tuple. +OutputLowLevelArray: TypeAlias = LowLevelValue | LowLevelArrays + +HighLevelObject: TypeAlias = Time | SkyCoord | SpectralCoord | StokesCoord | Quantity +HighLevelObjects: TypeAlias = tuple[HighLevelObject, ...] | HighLevelObject + +AxisPhysicalType: TypeAlias = str | BaseCoordinateFrame +AxisPhysicalTypes: TypeAlias = tuple[str | BaseCoordinateFrame, ...] diff --git a/gwcs/api.py b/gwcs/api.py index 4a563eda..042b28f3 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -5,13 +5,60 @@ """ +from collections.abc import Callable +from typing import Any, NamedTuple, TypeAlias + import astropy.units as u from astropy.modeling import separable from astropy.wcs.wcsapi import BaseLowLevelWCS, HighLevelWCSMixin from gwcs import utils -__all__ = ["GWCSAPIMixin"] +__all__ = [ + "GWCSAPIMixin", + "WorldAxisClass", + "WorldAxisClasses", + "WorldAxisComponent", + "WorldAxisComponents", + "WorldAxisConverterClass", +] + + +class WorldAxisClass(NamedTuple): + """ + Named tuple for the world_axis_object_classes WCS property + """ + + object_type: type | str + args: tuple[int | None, ...] + kwargs: dict[str, Any] + + +class WorldAxisConverterClass(NamedTuple): + """ + Named tuple for the world_axis_object_classes WCS property, which have a converter + """ + + object_type: type | str + args: tuple[int | None, ...] + kwargs: dict[str, Any] + converter: Callable[..., Any] | None = None + + +WorldAxisClasses: TypeAlias = dict[str | int, WorldAxisClass | WorldAxisConverterClass] + + +class WorldAxisComponent(NamedTuple): + """ + Named tuple for the world_axis_object_components WCS property + """ + + name: str + key: str | int + property_name: str | Callable[[Any], Any] + + +WorldAxisComponents: TypeAlias = list[WorldAxisComponent] class GWCSAPIMixin(BaseLowLevelWCS, HighLevelWCSMixin): diff --git a/gwcs/coordinate_frames.py b/gwcs/coordinate_frames.py deleted file mode 100644 index 62519b58..00000000 --- a/gwcs/coordinate_frames.py +++ /dev/null @@ -1,1157 +0,0 @@ -# Licensed under a 3-clause BSD style license - see LICENSE.rst -""" -This module defines coordinate frames for describing the inputs and/or outputs -of a transform. - -In the block diagram, the WCS pipeline has a two stage transformation (two -astropy Model instances), with an input frame, an output frame and an -intermediate frame. - -.. code-block:: - - ┌───────────────┐ - │ │ - │ Input │ - │ Frame │ - │ │ - └───────┬───────┘ - │ - ┌─────▼─────┐ - │ Transform │ - └─────┬─────┘ - │ - ┌───────▼───────┐ - │ │ - │ Intermediate │ - │ Frame │ - │ │ - └───────┬───────┘ - │ - ┌─────▼─────┐ - │ Transform │ - └─────┬─────┘ - │ - ┌───────▼───────┐ - │ │ - │ Output │ - │ Frame │ - │ │ - └───────────────┘ - - -Each frame instance is both metadata for the inputs/outputs of a transform and -also a converter between those inputs/outputs and richer coordinate -representations of those inputs/outputs. - -For example, an output frame of type `~gwcs.coordinate_frames.SpectralFrame` -provides metadata to the `.WCS` object such as the ``axes_type`` being -``"SPECTRAL"`` and the unit of the output etc. The output frame also provides a -converter of the numeric output of the transform to a -`~astropy.coordinates.SpectralCoord` object, by combining this metadata with the -numerical values. - -``axes_order`` and conversion between objects and arguments -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -One of the key concepts regarding coordinate frames is the ``axes_order`` argument. -This argument is used to map from the components of the frame to the inputs/outputs -of the transform. To illustrate this consider this situation where you have a -forward transform which outputs three coordinates ``[lat, lambda, lon]``. These -would be represented as a `.SpectralFrame` and a `.CelestialFrame`, however, the -axes of a `.CelestialFrame` are always ``[lon, lat]``, so by specifying two -frames as - -.. code-block:: python - - [SpectralFrame(axes_order=(1,)), CelestialFrame(axes_order=(2, 0))] - -we would map the outputs of this transform into the correct positions in the frames. - As shown below, this is also used when constructing the inputs to the inverse - transform. - - -When taking the output from the forward transform the following transformation -is performed by the coordinate frames: - -.. code-block:: - - lat, lambda, lon - │ │ │ - └──────┼─────┼────────┐ - ┌───────────┘ └──┐ │ - │ │ │ - ┌─────────▼────────┐ ┌──────▼─────▼─────┐ - │ │ │ │ - │ SpectralFrame │ │ CelestialFrame │ - │ │ │ │ - │ (1,) │ │ (2, 0) │ - │ │ │ │ - └─────────┬────────┘ └──────────┬────┬──┘ - │ │ │ - │ │ │ - ▼ ▼ ▼ - SpectralCoord(lambda) SkyCoord((lon, lat)) - - -When considering the backward transform the following transformations take place -in the coordinate frames before the transform is called: - -.. code-block:: - - SpectralCoord(lambda) SkyCoord((lon, lat)) - │ │ │ - └─────┐ ┌────────────┘ │ - │ │ ┌────────────┘ - ▼ ▼ ▼ - [lambda, lon, lat] - │ │ │ - │ │ │ - ┌──────▼─────▼────▼────┐ - │ │ - │ Sort by axes_order │ - │ │ - └────┬──────┬─────┬────┘ - │ │ │ - ▼ ▼ ▼ - lat, lambda, lon - -""" - -import abc -import contextlib -import logging -import numbers -from collections import defaultdict -from dataclasses import InitVar, dataclass - -import numpy as np -from astropy import coordinates as coord -from astropy import time -from astropy import units as u -from astropy import utils as astutil -from astropy.coordinates import StokesCoord -from astropy.utils.misc import isiterable -from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 -from astropy.wcs.wcsapi.high_level_api import ( - high_level_objects_to_values, - values_to_high_level_objects, -) -from astropy.wcs.wcsapi.low_level_api import VALID_UCDS, validate_physical_types - -__all__ = [ - "BaseCoordinateFrame", - "CelestialFrame", - "CompositeFrame", - "CoordinateFrame", - "EmptyFrame", - "Frame2D", - "SpectralFrame", - "StokesFrame", - "TemporalFrame", -] - - -def _ucd1_to_ctype_name_mapping(ctype_to_ucd, allowed_ucd_duplicates): - inv_map = {} - new_ucd = set() - - for kwd, ucd in ctype_to_ucd.items(): - if ucd in inv_map: - if ucd not in allowed_ucd_duplicates: - new_ucd.add(ucd) - continue - inv_map[ucd] = allowed_ucd_duplicates.get(ucd, kwd) - - if new_ucd: - logging.warning( - "Found unsupported duplicate physical type in 'astropy' mapping to CTYPE.\n" - "Update 'gwcs' to the latest version or notify 'gwcs' developer.\n" - "Duplicate physical types will be mapped to the following CTYPEs:\n" - + "\n".join([f"{ucd!r:s} --> {inv_map[ucd]!r:s}" for ucd in new_ucd]) - ) - - return inv_map - - -# List below allowed physical type duplicates and a corresponding CTYPE -# to which all duplicates will be mapped to: -_ALLOWED_UCD_DUPLICATES = { - "time": "TIME", - "em.wl": "WAVE", -} - -UCD1_TO_CTYPE = _ucd1_to_ctype_name_mapping( - ctype_to_ucd=CTYPE_TO_UCD1, allowed_ucd_duplicates=_ALLOWED_UCD_DUPLICATES -) - -STANDARD_REFERENCE_FRAMES = [frame.upper() for frame in coord.builtin_frames.__all__] - - -def get_ctype_from_ucd(ucd): - """ - Return the FITS ``CTYPE`` corresponding to a UCD1 value. - - Parameters - ---------- - ucd : str - UCD string, for example one of ```WCS.world_axis_physical_types``. - - Returns - ------- - CTYPE : str - The corresponding FITS ``CTYPE`` value or an empty string. - """ - return UCD1_TO_CTYPE.get(ucd, "") - - -@dataclass -class FrameProperties: - naxes: InitVar[int] - axes_type: tuple[str] - unit: tuple[u.Unit] = None - axes_names: tuple[str] = None - axis_physical_types: list[str] = None - - def __post_init__(self, naxes): - if isinstance(self.axes_type, str): - self.axes_type = (self.axes_type,) - else: - self.axes_type = tuple(self.axes_type) - - if len(self.axes_type) != naxes: - msg = "Length of axes_type does not match number of axes." - raise ValueError(msg) - - if self.unit is not None: - unit = tuple(self.unit) if astutil.isiterable(self.unit) else (self.unit,) - if len(unit) != naxes: - msg = "Number of units does not match number of axes." - raise ValueError(msg) - self.unit = tuple(u.Unit(au) for au in unit) - else: - self.unit = tuple(u.dimensionless_unscaled for na in range(naxes)) - - if self.axes_names is not None: - if isinstance(self.axes_names, str): - self.axes_names = (self.axes_names,) - else: - self.axes_names = tuple(self.axes_names) - if len(self.axes_names) != naxes: - msg = "Number of axes names does not match number of axes." - raise ValueError(msg) - else: - self.axes_names = tuple([""] * naxes) - - if self.axis_physical_types is not None: - if isinstance(self.axis_physical_types, str): - self.axis_physical_types = (self.axis_physical_types,) - elif not isiterable(self.axis_physical_types): - msg = ( - "axis_physical_types must be of type string or iterable of strings" - ) - raise TypeError(msg) - if len(self.axis_physical_types) != naxes: - msg = f'"axis_physical_types" must be of length {naxes}' - raise ValueError(msg) - ph_type = [] - for axt in self.axis_physical_types: - if axt not in VALID_UCDS and not axt.startswith("custom:"): - ph_type.append(f"custom:{axt}") - else: - ph_type.append(axt) - - validate_physical_types(ph_type) - self.axis_physical_types = tuple(ph_type) - - @property - def _default_axis_physical_types(self): - """ - The default physical types to use for this frame if none are specified - by the user. - """ - return tuple(f"custom:{t}" for t in self.axes_type) - - -class BaseCoordinateFrame(abc.ABC): - """ - API Definition for a Coordinate frame - """ - - _prop: FrameProperties - """ - The FrameProperties object holding properties in native frame order. - """ - - @property - @abc.abstractmethod - def naxes(self) -> int: - """ - The number of axes described by this frame. - """ - - @property - @abc.abstractmethod - def name(self) -> str: - """ - The name of the coordinate frame. - """ - - @property - @abc.abstractmethod - def unit(self) -> tuple[u.Unit, ...]: - """ - The units of the axes in this frame. - """ - - @property - @abc.abstractmethod - def axes_names(self) -> tuple[str, ...]: - """ - Names describing the axes of the frame. - """ - - @property - @abc.abstractmethod - def axes_order(self) -> tuple[int, ...]: - """ - The position of the axes in the frame in the transform. - """ - - @property - @abc.abstractmethod - def reference_frame(self): - """ - The reference frame of the coordinates described by this frame. - - This is usually an Astropy object such as ``SkyCoord`` or ``Time``. - """ - - @property - @abc.abstractmethod - def axes_type(self): - """ - An upcase string describing the type of the axis. - - Known values are ``"SPATIAL", "TEMPORAL", "STOKES", "SPECTRAL", "PIXEL"``. - """ - - @property - @abc.abstractmethod - def axis_physical_types(self): - """ - The UCD 1+ physical types for the axes, in frame order. - """ - - @property - @abc.abstractmethod - def world_axis_object_classes(self): - """ - The APE 14 object classes for this frame. - - See Also - -------- - astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_classes - """ - - @property - def world_axis_object_components(self): - """ - The APE 14 object components for this frame. - - See Also - -------- - astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components - """ - if self.naxes == 1: - return self._native_world_axis_object_components - - # If we have more than one axis then we should sort the native - # components by the axes_order. - ordered = np.array(self._native_world_axis_object_components, dtype=object)[ - np.argsort(self.axes_order) - ] - return list(map(tuple, ordered)) - - @property - @abc.abstractmethod - def _native_world_axis_object_components(self): - """ - This property holds the "native" frame order of the components. - - The native order of the components is the order the frame assumes the - axes are in when creating the high level objects, for example - ``CelestialFrame`` creates ``SkyCoord`` objects which are in lon, lat - order (in their positional args). - - This property is used both to construct the ordered - ``world_axis_object_components`` property as well as by `CompositeFrame` - to be able to get the components in their native order. - """ - - -class CoordinateFrame(BaseCoordinateFrame): - """ - Base class for Coordinate Frames. - - Parameters - ---------- - naxes : int - Number of axes. - axes_type : str - One of ["SPATIAL", "SPECTRAL", "TIME"] - axes_order : tuple of int - A dimension in the input data that corresponds to this axis. - reference_frame : astropy.coordinates.builtin_frames - Reference frame (usually used with output_frame to convert to world - coordinate objects). - unit : list of astropy.units.Unit - Unit for each axis. - axes_names : list - Names of the axes in this frame. - name : str - Name of this frame. - """ - - def __init__( - self, - naxes, - axes_type, - axes_order, - reference_frame=None, - unit=None, - axes_names=None, - name=None, - axis_physical_types=None, - ): - self._naxes = naxes - self._axes_order = tuple(axes_order) - self._reference_frame = reference_frame - - if name is None: - self._name = self.__class__.__name__ - else: - self._name = name - - if len(self._axes_order) != naxes: - msg = "Length of axes_order does not match number of axes." - raise ValueError(msg) - - if isinstance(axes_type, str): - axes_type = (axes_type,) - - self._prop = FrameProperties( - naxes, - axes_type, - unit, - axes_names, - axis_physical_types or self._default_axis_physical_types(axes_type), - ) - - super().__init__() - - def _default_axis_physical_types(self, axes_type): - """ - The default physical types to use for this frame if none are specified - by the user. - """ - return tuple(f"custom:{t}" for t in axes_type) - - def __repr__(self): - fmt = ( - f'<{self.__class__.__name__}(name="{self.name}", unit={self.unit}, ' - f"axes_names={self.axes_names}, axes_order={self.axes_order}" - ) - if self.reference_frame is not None: - fmt += f", reference_frame={self.reference_frame}" - fmt += ")>" - return fmt - - def __str__(self): - if self._name is not None: - return self._name - return self.__class__.__name__ - - def _sort_property(self, prop): - sorted_prop = sorted( - zip(prop, self.axes_order, strict=False), key=lambda x: x[1] - ) - return tuple([t[0] for t in sorted_prop]) - - @property - def name(self): - """A custom name of this frame.""" - return self._name - - @name.setter - def name(self, val): - """A custom name of this frame.""" - self._name = val - - @property - def naxes(self): - """The number of axes in this frame.""" - return self._naxes - - @property - def unit(self): - """The unit of this frame.""" - return self._sort_property(self._prop.unit) - - @property - def axes_names(self): - """Names of axes in the frame.""" - return self._sort_property(self._prop.axes_names) - - @property - def axes_order(self): - """A tuple of indices which map inputs to axes.""" - return self._axes_order - - @property - def reference_frame(self): - """Reference frame, used to convert to world coordinate objects.""" - return self._reference_frame - - @property - def axes_type(self): - """Type of this frame : 'SPATIAL', 'SPECTRAL', 'TIME'.""" - return self._sort_property(self._prop.axes_type) - - @property - def axis_physical_types(self): - """ - The axis physical types for this frame. - - These physical types are the types in frame order, not transform order. - """ - return self._sort_property(self._prop.axis_physical_types) - - @property - def world_axis_object_classes(self): - return { - f"{at}{i}" if i != 0 else at: (u.Quantity, (), {"unit": unit}) - for i, (at, unit) in enumerate(zip(self.axes_type, self.unit, strict=False)) - } - - @property - def _native_world_axis_object_components(self): - return [ - (f"{at}{i}" if i != 0 else at, 0, "value") - for i, at in enumerate(self._prop.axes_type) - ] - - @property - def serialized_classes(self): - """ - This property is used by the low level WCS API in Astropy. - - By providing it we can duck type as a low level WCS object. - """ - return False - - def to_high_level_coordinates(self, *values): - """ - Convert "values" to high level coordinate objects described by this frame. - - "values" are the coordinates in array or scalar form, and high level - objects are things such as ``SkyCoord`` or ``Quantity``. See - :ref:`wcsapi` for details. - - Parameters - ---------- - values : `numbers.Number`, `numpy.ndarray`, or `~astropy.units.Quantity` - ``naxis`` number of coordinates as scalars or arrays. - - Returns - ------- - high_level_coordinates - One (or more) high level object describing the coordinate. - """ - # We allow Quantity-like objects here which values_to_high_level_objects - # does not. - values = [ - v.to_value(unit) if hasattr(v, "to_value") else v - for v, unit in zip(values, self.unit, strict=False) - ] - - if not all( - isinstance(v, numbers.Number) or type(v) is np.ndarray for v in values - ): - msg = "All values should be a scalar number or a numpy array." - raise TypeError(msg) - - high_level = values_to_high_level_objects(*values, low_level_wcs=self) - if len(high_level) == 1: - high_level = high_level[0] - return high_level - - def from_high_level_coordinates(self, *high_level_coords): - """ - Convert high level coordinate objects to "values" as described by this frame. - - "values" are the coordinates in array or scalar form, and high level - objects are things such as ``SkyCoord`` or ``Quantity``. See - :ref:`wcsapi` for details. - - Parameters - ---------- - high_level_coordinates - One (or more) high level object describing the coordinate. - - Returns - ------- - values : `numbers.Number` or `numpy.ndarray` - ``naxis`` number of coordinates as scalars or arrays. - """ - values = high_level_objects_to_values(*high_level_coords, low_level_wcs=self) - if len(values) == 1: - values = values[0] - return values - - -class EmptyFrame(CoordinateFrame): - """ - Represents a "default" detector frame. This is for use as the default value - for input frame by the WCS object. - """ - - def __init__(self, name=None): - self._name = "detector" if name is None else name - - def __repr__(self): - return f'<{type(self).__name__}(name="{self.name}")>' - - def __str__(self): - if self._name is not None: - return self._name - return type(self).__name__ - - @property - def name(self): - """A custom name of this frame.""" - return self._name - - @name.setter - def name(self, val): - """A custom name of this frame.""" - self._name = val - - def _raise_error(self) -> None: - msg = "EmptyFrame does not have any information" - raise NotImplementedError(msg) - - @property - def naxes(self): - self._raise_error() - - @property - def unit(self): - self._raise_error() - - @property - def axes_names(self): - self._raise_error() - - @property - def axes_order(self): - self._raise_error() - - @property - def reference_frame(self): - self._raise_error() - - @property - def axes_type(self): - self._raise_error() - - @property - def axis_physical_types(self): - self._raise_error() - - @property - def world_axis_object_classes(self): - self._raise_error() - - @property - def _native_world_axis_object_components(self): - self._raise_error() - - def to_high_level_coordinates(self, *values): - self._raise_error() - - def from_high_level_coordinates(self, *high_level_coords): - self._raise_error() - - -class CelestialFrame(CoordinateFrame): - """ - Representation of a Celesital coordinate system. - - This class has a native order of longitude then latitude, meaning - ``axes_names``, ``unit`` and ``axis_physical_types`` should be lon, lat - ordered. If your transform is in a different order this should be specified - with ``axes_order``. - - Parameters - ---------- - axes_order : tuple of int - A dimension in the input data that corresponds to this axis. - reference_frame : astropy.coordinates.builtin_frames - A reference frame. - unit : str or units.Unit instance or iterable of those - Units on axes. - axes_names : list - Names of the axes in this frame. - name : str - Name of this frame. - axis_physical_types : list - The UCD 1+ physical types for the axes, in frame order (lon, lat). - """ - - def __init__( - self, - axes_order=None, - reference_frame=None, - unit=None, - axes_names=None, - name=None, - axis_physical_types=None, - ): - naxes = 2 - if ( - reference_frame is not None - and not isinstance(reference_frame, str) - and reference_frame.name.upper() in STANDARD_REFERENCE_FRAMES - ): - _axes_names = list(reference_frame.representation_component_names.values()) - if "distance" in _axes_names: - _axes_names.remove("distance") - if axes_names is None: - axes_names = _axes_names - naxes = len(_axes_names) - - self.native_axes_order = tuple(range(naxes)) - if axes_order is None: - axes_order = self.native_axes_order - if unit is None: - unit = tuple([u.degree] * naxes) - axes_type = ["SPATIAL"] * naxes - - pht = axis_physical_types or self._default_axis_physical_types( - reference_frame, axes_names - ) - super().__init__( - naxes=naxes, - axes_type=axes_type, - axes_order=axes_order, - reference_frame=reference_frame, - unit=unit, - axes_names=axes_names, - name=name, - axis_physical_types=pht, - ) - - def _default_axis_physical_types(self, reference_frame, axes_names): - if isinstance(reference_frame, coord.Galactic): - return "pos.galactic.lon", "pos.galactic.lat" - if isinstance( - reference_frame, - coord.GeocentricTrueEcliptic | coord.GCRS | coord.PrecessedGeocentric, - ): - return "pos.bodyrc.lon", "pos.bodyrc.lat" - if isinstance(reference_frame, coord.builtin_frames.BaseRADecFrame): - return "pos.eq.ra", "pos.eq.dec" - if isinstance(reference_frame, coord.builtin_frames.BaseEclipticFrame): - return "pos.ecliptic.lon", "pos.ecliptic.lat" - return tuple(f"custom:{t}" for t in axes_names) - - @property - def world_axis_object_classes(self): - return { - "celestial": ( - coord.SkyCoord, - (), - {"frame": self.reference_frame, "unit": self._prop.unit}, - ) - } - - @property - def _native_world_axis_object_components(self): - return [ - ("celestial", 0, lambda sc: sc.spherical.lon.to_value(self._prop.unit[0])), - ("celestial", 1, lambda sc: sc.spherical.lat.to_value(self._prop.unit[1])), - ] - - -class SpectralFrame(CoordinateFrame): - """ - Represents Spectral Frame - - Parameters - ---------- - axes_order : tuple or int - A dimension in the input data that corresponds to this axis. - reference_frame : astropy.coordinates.builtin_frames - Reference frame (usually used with output_frame to convert to world - coordinate objects). - unit : str or units.Unit instance - Spectral unit. - axes_names : str - Spectral axis name. - name : str - Name for this frame. - - """ - - def __init__( - self, - axes_order=(0,), - reference_frame=None, - unit=None, - axes_names=None, - name=None, - axis_physical_types=None, - ): - if not isiterable(unit): - unit = (unit,) - unit = [u.Unit(un) for un in unit] - pht = axis_physical_types or self._default_axis_physical_types(unit) - - super().__init__( - naxes=1, - axes_type="SPECTRAL", - axes_order=axes_order, - axes_names=axes_names, - reference_frame=reference_frame, - unit=unit, - name=name, - axis_physical_types=pht, - ) - - def _default_axis_physical_types(self, unit): - if unit[0].physical_type == "frequency": - return ("em.freq",) - if unit[0].physical_type == "length": - return ("em.wl",) - if unit[0].physical_type == "energy": - return ("em.energy",) - if unit[0].physical_type == "speed": - return ("spect.dopplerVeloc",) - logging.warning( - "Physical type may be ambiguous. Consider " - "setting the physical type explicitly as " - "either 'spect.dopplerVeloc.optical' or " - "'spect.dopplerVeloc.radio'." - ) - return (f"custom:{unit[0].physical_type}",) - - @property - def world_axis_object_classes(self): - return {"spectral": (coord.SpectralCoord, (), {"unit": self.unit[0]})} - - @property - def _native_world_axis_object_components(self): - return [("spectral", 0, lambda sc: sc.to_value(self.unit[0]))] - - -class TemporalFrame(CoordinateFrame): - """ - A coordinate frame for time axes. - - Parameters - ---------- - reference_frame : `~astropy.time.Time` - A Time object which holds the time scale and format. - If data is provided, it is the time zero point. - To not set a zero point for the frame initialize ``reference_frame`` - with an empty list. - unit : str or `~astropy.units.Unit` - Time unit. - axes_names : str - Time axis name. - axes_order : tuple or int - A dimension in the data that corresponds to this axis. - name : str - Name for this frame. - """ - - def __init__( - self, - reference_frame, - unit=u.s, - axes_order=(0,), - axes_names=None, - name=None, - axis_physical_types=None, - ): - axes_names = ( - axes_names - or f"{reference_frame.format}({reference_frame.scale}; " - f"{reference_frame.location}" - ) - - pht = axis_physical_types or self._default_axis_physical_types() - - super().__init__( - naxes=1, - axes_type="TIME", - axes_order=axes_order, - axes_names=axes_names, - reference_frame=reference_frame, - unit=unit, - name=name, - axis_physical_types=pht, - ) - self._attrs = {} - for a in self.reference_frame.info._represent_as_dict_extra_attrs: - with contextlib.suppress(AttributeError): - self._attrs[a] = getattr(self.reference_frame, a) - - def _default_axis_physical_types(self): - return ("time",) - - def _convert_to_time(self, dt, *, unit, **kwargs): - if ( - not isinstance(dt, time.TimeDelta) and isinstance(dt, time.Time) - ) or isinstance(self.reference_frame.value, np.ndarray): - return time.Time(dt, **kwargs) - - if not hasattr(dt, "unit"): - dt = dt * unit - - return self.reference_frame + dt - - @property - def world_axis_object_classes(self): - comp = ( - time.Time, - (), - {"unit": self.unit[0], **self._attrs}, - self._convert_to_time, - ) - - return {"temporal": comp} - - @property - def _native_world_axis_object_components(self): - if isinstance(self.reference_frame.value, np.ndarray): - return [("temporal", 0, "value")] - - def offset_from_time_and_reference(time): - return (time - self.reference_frame).sec - - return [("temporal", 0, offset_from_time_and_reference)] - - -class CompositeFrame(CoordinateFrame): - """ - Represents one or more frames. - - Parameters - ---------- - frames : list - List of constituient frames. - name : str - Name for this frame. - """ - - def __init__(self, frames, name=None): - self._frames = frames[:] - naxes = sum([frame._naxes for frame in self._frames]) - - axes_order = [] - axes_type = [] - axes_names = [] - unit = [] - ph_type = [] - - for frame in frames: - axes_order.extend(frame.axes_order) - - # Stack the raw (not-native) ordered properties - for frame in frames: - axes_type += list(frame._prop.axes_type) - axes_names += list(frame._prop.axes_names) - unit += list(frame._prop.unit) - ph_type += list(frame._prop.axis_physical_types) - - if len(np.unique(axes_order)) != len(axes_order): - msg = ( - "Incorrect numbering of axes, " - "axes_order should contain unique numbers, " - f"got {axes_order}." - ) - raise ValueError(msg) - - super().__init__( - naxes, - axes_type=axes_type, - axes_order=axes_order, - unit=unit, - axes_names=axes_names, - axis_physical_types=tuple(ph_type), - name=name, - ) - self._axis_physical_types = tuple(ph_type) - - @property - def frames(self): - """ - The constituient frames that comprise this `CompositeFrame`. - """ - return self._frames - - def __repr__(self): - return repr(self.frames) - - @property - def _wao_classes_rename_map(self): - mapper = defaultdict(dict) - seen_names = [] - for frame in self.frames: - # ensure the frame is in the mapper - mapper[frame] - for key in frame.world_axis_object_classes: - if key in seen_names: - new_key = f"{key}{seen_names.count(key)}" - mapper[frame][key] = new_key - seen_names.append(key) - return mapper - - @property - def _wao_renamed_components_iter(self): - mapper = self._wao_classes_rename_map - for frame in self.frames: - renamed_components = [] - for component in frame._native_world_axis_object_components: - comp = list(component) - rename = mapper[frame].get(comp[0]) - if rename: - comp[0] = rename - renamed_components.append(tuple(comp)) - yield frame, renamed_components - - @property - def _wao_renamed_classes_iter(self): - mapper = self._wao_classes_rename_map - for frame in self.frames: - for key, value in frame.world_axis_object_classes.items(): - rename = mapper[frame].get(key) - yield rename if rename else key, value - - @property - def world_axis_object_components(self): - out = [None] * self.naxes - - for frame, components in self._wao_renamed_components_iter: - for i, ao in enumerate(frame.axes_order): - out[ao] = components[i] - - if any(o is None for o in out): - msg = "axes_order leads to incomplete world_axis_object_components" - raise ValueError(msg) - - return out - - @property - def world_axis_object_classes(self): - return dict(self._wao_renamed_classes_iter) - - -class StokesFrame(CoordinateFrame): - """ - A coordinate frame for representing Stokes polarisation states. - - Parameters - ---------- - name : str - Name of this frame. - axes_order : tuple - A dimension in the data that corresponds to this axis. - """ - - def __init__( - self, - axes_order=(0,), - axes_names=("stokes",), - name=None, - axis_physical_types=None, - ): - pht = axis_physical_types or self._default_axis_physical_types() - - super().__init__( - 1, - ["STOKES"], - axes_order, - name=name, - axes_names=axes_names, - unit=u.one, - axis_physical_types=pht, - ) - - def _default_axis_physical_types(self): - return ("phys.polarization.stokes",) - - @property - def world_axis_object_classes(self): - return { - "stokes": ( - StokesCoord, - (), - {}, - ) - } - - @property - def _native_world_axis_object_components(self): - return [("stokes", 0, "value")] - - -class Frame2D(CoordinateFrame): - """ - A 2D coordinate frame. - - Parameters - ---------- - axes_order : tuple of int - A dimension in the input data that corresponds to this axis. - unit : list of astropy.units.Unit - Unit for each axis. - axes_names : list - Names of the axes in this frame. - name : str - Name of this frame. - """ - - def __init__( - self, - axes_order=(0, 1), - unit=(u.pix, u.pix), - axes_names=("x", "y"), - name=None, - axes_type=None, - axis_physical_types=None, - ): - if axes_type is None: - axes_type = ["SPATIAL", "SPATIAL"] - pht = axis_physical_types or self._default_axis_physical_types( - axes_names, axes_type - ) - - super().__init__( - naxes=2, - axes_type=axes_type, - axes_order=axes_order, - name=name, - axes_names=axes_names, - unit=unit, - axis_physical_types=pht, - ) - - def _default_axis_physical_types(self, axes_names, axes_type): - if axes_names is not None and all(axes_names): - ph_type = axes_names - else: - ph_type = axes_type - - return tuple(f"custom:{t}" for t in ph_type) diff --git a/gwcs/coordinate_frames/__init__.py b/gwcs/coordinate_frames/__init__.py new file mode 100644 index 00000000..a53b6412 --- /dev/null +++ b/gwcs/coordinate_frames/__init__.py @@ -0,0 +1,143 @@ +""" +This module defines coordinate frames for describing the inputs and/or outputs +of a transform. + +In the block diagram, the WCS pipeline has a two stage transformation (two +astropy Model instances), with an input frame, an output frame and an +intermediate frame. + +.. code-block:: + + ┌───────────────┐ + │ │ + │ Input │ + │ Frame │ + │ │ + └───────┬───────┘ + │ + ┌─────▼─────┐ + │ Transform │ + └─────┬─────┘ + │ + ┌───────▼───────┐ + │ │ + │ Intermediate │ + │ Frame │ + │ │ + └───────┬───────┘ + │ + ┌─────▼─────┐ + │ Transform │ + └─────┬─────┘ + │ + ┌───────▼───────┐ + │ │ + │ Output │ + │ Frame │ + │ │ + └───────────────┘ + + +Each frame instance is both metadata for the inputs/outputs of a transform and +also a converter between those inputs/outputs and richer coordinate +representations of those inputs/outputs. + +For example, an output frame of type `~gwcs.coordinate_frames.SpectralFrame` +provides metadata to the `.WCS` object such as the ``axes_type`` being +``"SPECTRAL"`` and the unit of the output etc. The output frame also provides a +converter of the numeric output of the transform to a +`~astropy.coordinates.SpectralCoord` object, by combining this metadata with the +numerical values. + +``axes_order`` and conversion between objects and arguments +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +One of the key concepts regarding coordinate frames is the ``axes_order`` argument. +This argument is used to map from the components of the frame to the inputs/outputs +of the transform. To illustrate this consider this situation where you have a +forward transform which outputs three coordinates ``[lat, lambda, lon]``. These +would be represented as a `.SpectralFrame` and a `.CelestialFrame`, however, the +axes of a `.CelestialFrame` are always ``[lon, lat]``, so by specifying two +frames as + +.. code-block:: python + + [SpectralFrame(axes_order=(1,)), CelestialFrame(axes_order=(2, 0))] + +we would map the outputs of this transform into the correct positions in the frames. + As shown below, this is also used when constructing the inputs to the inverse + transform. + + +When taking the output from the forward transform the following transformation +is performed by the coordinate frames: + +.. code-block:: + + lat, lambda, lon + │ │ │ + └──────┼─────┼────────┐ + ┌───────────┘ └──┐ │ + │ │ │ + ┌─────────▼────────┐ ┌──────▼─────▼─────┐ + │ │ │ │ + │ SpectralFrame │ │ CelestialFrame │ + │ │ │ │ + │ (1,) │ │ (2, 0) │ + │ │ │ │ + └─────────┬────────┘ └──────────┬────┬──┘ + │ │ │ + │ │ │ + ▼ ▼ ▼ + SpectralCoord(lambda) SkyCoord((lon, lat)) + + +When considering the backward transform the following transformations take place +in the coordinate frames before the transform is called: + +.. code-block:: + + SpectralCoord(lambda) SkyCoord((lon, lat)) + │ │ │ + └─────┐ ┌────────────┘ │ + │ │ ┌────────────┘ + ▼ ▼ ▼ + [lambda, lon, lat] + │ │ │ + │ │ │ + ┌──────▼─────▼────▼────┐ + │ │ + │ Sort by axes_order │ + │ │ + └────┬──────┬─────┬────┘ + │ │ │ + ▼ ▼ ▼ + lat, lambda, lon + +""" + +from ._axis import AxisType +from ._base import BaseCoordinateFrame +from ._celestial import CelestialFrame +from ._composite import CompositeFrame +from ._core import CoordinateFrame +from ._empty import EmptyFrame +from ._frame import Frame2D +from ._spectral import SpectralFrame +from ._stokes import StokesFrame +from ._temporal import TemporalFrame +from ._utils import get_ctype_from_ucd + +__all__ = [ + "AxisType", + "BaseCoordinateFrame", + "CelestialFrame", + "CompositeFrame", + "CoordinateFrame", + "EmptyFrame", + "Frame2D", + "SpectralFrame", + "StokesFrame", + "TemporalFrame", + "get_ctype_from_ucd", +] diff --git a/gwcs/coordinate_frames/_axis.py b/gwcs/coordinate_frames/_axis.py new file mode 100644 index 00000000..fbaaa33b --- /dev/null +++ b/gwcs/coordinate_frames/_axis.py @@ -0,0 +1,24 @@ +from enum import StrEnum +from typing import TypeAlias + +__all__ = ["AxesType", "AxisType"] + + +class AxisType(StrEnum): + """ + Enumeration of the Axis types + """ + + SPATIAL = "SPATIAL" + SPECTRAL = "SPECTRAL" + TIME = "TIME" + STOKES = "STOKES" + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return self.value + + +AxesType: TypeAlias = tuple[AxisType | str, ...] | AxisType | str diff --git a/gwcs/coordinate_frames/_base.py b/gwcs/coordinate_frames/_base.py new file mode 100644 index 00000000..41af62ef --- /dev/null +++ b/gwcs/coordinate_frames/_base.py @@ -0,0 +1,173 @@ +import abc +from collections.abc import Sequence +from typing import Any, cast + +import numpy as np +import numpy.typing as npt +from astropy import units as u +from astropy.coordinates import BaseCoordinateFrame as _BaseCoordinateFrame +from astropy.time import Time + +from gwcs._typing import AxisPhysicalTypes, LowLevelUnitValue +from gwcs.api import WorldAxisClasses, WorldAxisComponent, WorldAxisComponents + +from ._axis import AxesType +from ._properties import FrameProperties + +__all__ = ["BaseCoordinateFrame"] + + +class BaseCoordinateFrame(abc.ABC): + """ + API Definition for a Coordinate frame + """ + + _prop: FrameProperties + """ + The FrameProperties object holding properties in native frame order. + """ + + @property + @abc.abstractmethod + def naxes(self) -> int: + """ + The number of axes described by this frame. + """ + + @property + @abc.abstractmethod + def name(self) -> str: + """ + The name of the coordinate frame. + """ + + @property + @abc.abstractmethod + def unit(self) -> tuple[u.Unit, ...]: + """ + The units of the axes in this frame. + """ + + @property + @abc.abstractmethod + def axes_names(self) -> tuple[str, ...]: + """ + Names describing the axes of the frame. + """ + + @property + @abc.abstractmethod + def axes_order(self) -> tuple[int, ...]: + """ + The position of the axes in the frame in the transform. + """ + + @property + @abc.abstractmethod + def reference_frame(self) -> _BaseCoordinateFrame | Time | None: + """ + The reference frame of the coordinates described by this frame. + + This is usually an Astropy object such as ``SkyCoord`` or ``Time``. + """ + + @property + @abc.abstractmethod + def axes_type(self) -> AxesType: + """ + An upcase string describing the type of the axis. + + Known values are ``"SPATIAL", "TEMPORAL", "STOKES", "SPECTRAL", "PIXEL"``. + """ + + @property + @abc.abstractmethod + def axis_physical_types(self) -> AxisPhysicalTypes: + """ + The UCD 1+ physical types for the axes, in frame order. + """ + + @property + @abc.abstractmethod + def world_axis_object_classes(self) -> WorldAxisClasses: + """ + The APE 14 object classes for this frame. + + See Also + -------- + astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_classes + """ + + @property + def world_axis_object_components(self) -> WorldAxisComponents: + """ + The APE 14 object components for this frame. + + See Also + -------- + astropy.wcs.wcsapi.BaseLowLevelWCS.world_axis_object_components + """ + if self.naxes == 1: + return self._native_world_axis_object_components + + # If we have more than one axis then we should sort the native + # components by the axes_order. + ordered = np.array(self._native_world_axis_object_components, dtype=object)[ + np.argsort(self.axes_order) + ].tolist() + + # Unpack the arguments passed by the map into the WorldAxisComponent NamedTuple + # NamedTuple apparently does not automatically unpack the arguments like a tuple + # will + def _func(arg: Sequence[Any]) -> WorldAxisComponent: + return WorldAxisComponent(*arg) + + return list(map(_func, ordered)) + + @property + @abc.abstractmethod + def _native_world_axis_object_components(self) -> WorldAxisComponents: + """ + This property holds the "native" frame order of the components. + + The native order of the components is the order the frame assumes the + axes are in when creating the high level objects, for example + ``CelestialFrame`` creates ``SkyCoord`` objects which are in lon, lat + order (in their positional args). + + This property is used both to construct the ordered + ``world_axis_object_components`` property as well as by `CompositeFrame` + to be able to get the components in their native order. + """ + + def add_units( + self, arrays: tuple[LowLevelUnitValue, ...] + ) -> tuple[u.Quantity, ...]: + """ + Add units to the arrays + """ + return tuple( + u.Quantity(array, unit=unit) # type: ignore[arg-type] + for array, unit in zip(arrays, self.unit, strict=True) + ) + + def remove_units( + self, arrays: tuple[LowLevelUnitValue, ...] | LowLevelUnitValue + ) -> tuple[npt.NDArray[np.number], ...]: + """ + Remove units from the input arrays + """ + if self.naxes == 1: + arrays = (cast(LowLevelUnitValue, arrays),) + + return tuple( + cast( + npt.NDArray[np.number], + array.to_value(unit) # type: ignore[no-untyped-call] + if isinstance(array, u.Quantity) + else array, + ) + for array, unit in zip( + cast(tuple[LowLevelUnitValue, ...], arrays), self.unit, strict=True + ) + ) diff --git a/gwcs/coordinate_frames/_celestial.py b/gwcs/coordinate_frames/_celestial.py new file mode 100644 index 00000000..f68383bc --- /dev/null +++ b/gwcs/coordinate_frames/_celestial.py @@ -0,0 +1,134 @@ +from astropy import units as u +from astropy.coordinates import ( + GCRS, + BaseCoordinateFrame, + Galactic, + GeocentricTrueEcliptic, + PrecessedGeocentric, + SkyCoord, + builtin_frames, +) + +from gwcs._typing import AxisPhysicalTypes +from gwcs.api import ( + WorldAxisClass, + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, +) + +from ._axis import AxisType +from ._core import CoordinateFrame +from ._properties import FrameProperties + +__all__ = ["CelestialFrame"] + +STANDARD_REFERENCE_FRAMES = [frame.upper() for frame in builtin_frames.__all__] + + +class CelestialFrame(CoordinateFrame): + """ + Representation of a Celesital coordinate system. + + This class has a native order of longitude then latitude, meaning + ``axes_names``, ``unit`` and ``axis_physical_types`` should be lon, lat + ordered. If your transform is in a different order this should be specified + with ``axes_order``. + + Parameters + ---------- + axes_order + A dimension in the input data that corresponds to this axis. + reference_frame + A reference frame. + unit + Units on axes. + axes_names + Names of the axes in this frame. + name + Name of this frame. + axis_physical_types + The UCD 1+ physical types for the axes, in frame order (lon, lat). + """ + + def __init__( + self, + axes_order: tuple[int, ...] | None = None, + reference_frame: BaseCoordinateFrame | None = None, + unit: tuple[u.Unit, ...] | None = None, + axes_names: tuple[str, ...] | None = None, + name: str | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + naxes = 2 + if ( + reference_frame is not None + and not isinstance(reference_frame, str) + and reference_frame.name.upper() in STANDARD_REFERENCE_FRAMES + ): + _axes_names = ( + tuple( + n + for n in reference_frame.representation_component_names.values() + if n != "distance" + ) + if axes_names is None + else axes_names + ) + naxes = len(_axes_names) + + self.native_axes_order = tuple(range(naxes)) + if axes_order is None: + axes_order = self.native_axes_order + if unit is None: + # Astropy dynamically creates some units, so MyPy can't find them + unit = tuple([u.degree] * naxes) # type: ignore[attr-defined] + axes_type = (AxisType.SPATIAL,) * naxes + + super().__init__( + naxes=naxes, + axes_type=axes_type, + axes_order=axes_order, + reference_frame=reference_frame, + unit=unit, + axes_names=_axes_names, + name=name, + axis_physical_types=axis_physical_types, + ) + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + if isinstance(self.reference_frame, Galactic): + return "pos.galactic.lon", "pos.galactic.lat" + if isinstance( + self.reference_frame, + GeocentricTrueEcliptic | GCRS | PrecessedGeocentric, + ): + return "pos.bodyrc.lon", "pos.bodyrc.lat" + if isinstance(self.reference_frame, builtin_frames.BaseRADecFrame): + return "pos.eq.ra", "pos.eq.dec" + if isinstance(self.reference_frame, builtin_frames.BaseEclipticFrame): + return "pos.ecliptic.lon", "pos.ecliptic.lat" + return tuple(f"custom:{t}" for t in properties.axes_names) + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return { + "celestial": WorldAxisClass( + SkyCoord, + (), + {"frame": self.reference_frame, "unit": self._prop.unit}, + ) + } + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: + return [ + WorldAxisComponent( + "celestial", 0, lambda sc: sc.spherical.lon.to_value(self._prop.unit[0]) + ), + WorldAxisComponent( + "celestial", 1, lambda sc: sc.spherical.lat.to_value(self._prop.unit[1]) + ), + ] diff --git a/gwcs/coordinate_frames/_composite.py b/gwcs/coordinate_frames/_composite.py new file mode 100644 index 00000000..e9f35d46 --- /dev/null +++ b/gwcs/coordinate_frames/_composite.py @@ -0,0 +1,149 @@ +from collections import defaultdict +from collections.abc import Generator +from typing import Any, cast + +import numpy as np +from astropy import units as u + +from gwcs._typing import AxisPhysicalType +from gwcs.api import ( + WorldAxisClass, + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, + WorldAxisConverterClass, +) + +from ._axis import AxisType +from ._base import BaseCoordinateFrame +from ._core import CoordinateFrame + +__all__ = ["CompositeFrame"] + + +class CompositeFrame(CoordinateFrame): + """ + Represents one or more frames. + + Parameters + ---------- + frames + List of constituient frames. + name + Name for this frame. + """ + + def __init__( + self, frames: list[BaseCoordinateFrame], name: str | None = None + ) -> None: + self._frames = frames[:] + naxes = sum([frame.naxes for frame in self._frames]) + + axes_order: list[int] = [] + axes_type: list[AxisType | str] = [] + axes_names: list[str] = [] + unit: list[u.Unit] = [] + ph_type: list[AxisPhysicalType] = [] + + for frame in frames: + axes_order.extend(frame.axes_order) + + # Stack the raw (not-native) ordered properties + for frame in frames: + axes_type += list(frame._prop.axes_type) + axes_names += list(frame._prop.axes_names) + # no common base class in astropy.units for all units + unit += list(frame._prop.unit) # type: ignore[arg-type] + ph_type += list(frame._prop.axis_physical_types) + + if len(np.unique(axes_order)) != len(axes_order): + msg = ( + "Incorrect numbering of axes, " + "axes_order should contain unique numbers, " + f"got {axes_order}." + ) + raise ValueError(msg) + + super().__init__( + naxes, + axes_type=tuple(axes_type), + axes_order=tuple(axes_order), + unit=tuple(unit), + axes_names=tuple(axes_names), + axis_physical_types=tuple(ph_type), + name=name, + ) + self._axis_physical_types = tuple(ph_type) + + @property + def frames(self) -> list[BaseCoordinateFrame]: + """ + The constituient frames that comprise this `CompositeFrame`. + """ + return self._frames + + def __repr__(self) -> str: + return repr(self.frames) + + @property + def _wao_classes_rename_map(self) -> dict[Any, Any]: + mapper: dict[Any, Any] = defaultdict(dict) + seen_names: list[str] = [] + for frame in self.frames: + # ensure the frame is in the mapper + mapper[frame] + for key in frame.world_axis_object_classes: + key = cast(str, key) + if key in seen_names: + new_key = f"{key}{seen_names.count(key)}" + mapper[frame][key] = new_key + seen_names.append(key) + return mapper + + @property + def _wao_renamed_components_iter( + self, + ) -> Generator[tuple[BaseCoordinateFrame, list[WorldAxisComponent]], None, None]: + mapper = self._wao_classes_rename_map + for frame in self.frames: + renamed_components: list[WorldAxisComponent] = [] + for component in frame._native_world_axis_object_components: + rename: str = mapper[frame].get(component[0], component[0]) + renamed_components.append( + WorldAxisComponent(rename, component[1], component[2]) + ) + + yield frame, renamed_components + + @property + def _wao_renamed_classes_iter( + self, + ) -> Generator[tuple[str, WorldAxisClass | WorldAxisConverterClass], None, None]: + mapper = self._wao_classes_rename_map + for frame in self.frames: + for key, value in frame.world_axis_object_classes.items(): + rename: str = mapper[frame].get(key, key) + yield rename, value + + @property + def world_axis_object_components(self) -> WorldAxisComponents: + """ + Object components for this frame. + """ + out: list[WorldAxisComponent | None] = [None] * self.naxes + + for frame, components in self._wao_renamed_components_iter: + for i, ao in enumerate(frame.axes_order): + out[ao] = components[i] + + if any(o is None for o in out): + msg = "axes_order leads to incomplete world_axis_object_components" + raise ValueError(msg) + + # There can be None in the list here, but this is unique to this and + # annoying otherwise, so we ignore MyPy + return out # type: ignore[return-value] + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return dict(self._wao_renamed_classes_iter) diff --git a/gwcs/coordinate_frames/_core.py b/gwcs/coordinate_frames/_core.py new file mode 100644 index 00000000..8c4fb799 --- /dev/null +++ b/gwcs/coordinate_frames/_core.py @@ -0,0 +1,263 @@ +import numbers +from typing import TypeVar + +import numpy as np +from astropy import units as u +from astropy.coordinates import BaseCoordinateFrame as _BaseCoordinateFrame +from astropy.time import Time +from astropy.wcs.wcsapi.high_level_api import ( + high_level_objects_to_values, + values_to_high_level_objects, +) + +from gwcs._typing import ( + AxisPhysicalTypes, + HighLevelObject, + HighLevelObjects, + LowLevelArrays, + LowLevelValue, +) +from gwcs.api import ( + WorldAxisClass, + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, +) + +from ._axis import AxesType +from ._base import BaseCoordinateFrame +from ._properties import FrameProperties + +__all__ = ["CoordinateFrame"] + +_T = TypeVar("_T") + + +class CoordinateFrame(BaseCoordinateFrame): + """ + Base class for Coordinate Frames. + + Parameters + ---------- + naxes + Number of axes. + axes_type + One of ["SPATIAL", "SPECTRAL", "TIME"] + axes_order + A dimension in the input data that corresponds to this axis. + reference_frame + Reference frame (usually used with output_frame to convert to world + coordinate objects). + unit + Unit for each axis. + axes_names + Names of the axes in this frame. + name + Name of this frame. + """ + + def __init__( + self, + naxes: int, + axes_type: AxesType, + axes_order: tuple[int, ...], + reference_frame: _BaseCoordinateFrame | Time | None = None, + unit: tuple[u.Unit, ...] | None = None, + axes_names: tuple[str, ...] | None = None, + name: str | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + self._naxes = naxes + self._axes_order = tuple(axes_order) + self._reference_frame = reference_frame + + if name is None: + self._name = type(self).__name__ + else: + self._name = name + + if len(self._axes_order) != naxes: + msg = "Length of axes_order does not match number of axes." + raise ValueError(msg) + + if isinstance(axes_type, str): + axes_type = (axes_type,) + + self._prop = FrameProperties( + naxes, + axes_type, + unit, + axes_names, + axis_physical_types, + self._default_axis_physical_types, + ) + + super().__init__() + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + """ + The default physical types to use for this frame if none are specified + by the user. + """ + return tuple(f"custom:{t}" for t in properties.axes_type) + + def __repr__(self) -> str: + fmt = ( + f'<{self.__class__.__name__}(name="{self.name}", unit={self.unit}, ' + f"axes_names={self.axes_names}, axes_order={self.axes_order}" + ) + if self.reference_frame is not None: + fmt += f", reference_frame={self.reference_frame}" + fmt += ")>" + return fmt + + def __str__(self) -> str: + return self._name + + def _sort_property(self, prop: tuple[_T, ...]) -> tuple[_T, ...]: + sorted_prop = sorted( + zip(prop, self.axes_order, strict=False), key=lambda x: x[1] + ) + return tuple([t[0] for t in sorted_prop]) + + @property + def name(self) -> str: + """A custom name of this frame.""" + return self._name + + @name.setter + def name(self, val: str) -> None: + """A custom name of this frame.""" + self._name = val + + @property + def naxes(self) -> int: + """The number of axes in this frame.""" + return self._naxes + + @property + def unit(self) -> tuple[u.Unit, ...]: + """The unit of this frame.""" + return self._sort_property(self._prop.unit) # type: ignore[arg-type] + + @property + def axes_names(self) -> tuple[str, ...]: + """Names of axes in the frame.""" + return self._sort_property(self._prop.axes_names) + + @property + def axes_order(self) -> tuple[int, ...]: + """A tuple of indices which map inputs to axes.""" + return self._axes_order + + @property + def reference_frame(self) -> _BaseCoordinateFrame | Time | None: + """Reference frame, used to convert to world coordinate objects.""" + return self._reference_frame + + @property + def axes_type(self) -> AxesType: + """Type of this frame : 'SPATIAL', 'SPECTRAL', 'TIME'.""" + return self._sort_property(self._prop.axes_type) + + @property + def axis_physical_types(self) -> AxisPhysicalTypes: + """ + The axis physical types for this frame. + + These physical types are the types in frame order, not transform order. + """ + return self._sort_property(self._prop.axis_physical_types) + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return { + f"{at}{i}" if i != 0 else str(at): WorldAxisClass( + u.Quantity, (), {"unit": unit} + ) + for i, (at, unit) in enumerate(zip(self.axes_type, self.unit, strict=False)) + } + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: + return [ + WorldAxisComponent(f"{at}{i}" if i != 0 else at, 0, "value") + for i, at in enumerate(self._prop.axes_type) + ] + + @property + def serialized_classes(self) -> bool: + """ + This property is used by the low level WCS API in Astropy. + + By providing it we can duck type as a low level WCS object. + """ + return False + + def to_high_level_coordinates(self, *values: LowLevelValue) -> HighLevelObjects: + """ + Convert "values" to high level coordinate objects described by this frame. + + "values" are the coordinates in array or scalar form, and high level + objects are things such as ``SkyCoord`` or ``Quantity``. See + :ref:`wcsapi` for details. + + Parameters + ---------- + values : `numbers.Number`, `numpy.ndarray`, or `~astropy.units.Quantity` + ``naxis`` number of coordinates as scalars or arrays. + + Returns + ------- + high_level_coordinates + One (or more) high level object describing the coordinate. + """ + # We allow Quantity-like objects here which values_to_high_level_objects + # does not. + values = tuple( + v.to_value(unit) if hasattr(v, "to_value") else v + for v, unit in zip(values, self.unit, strict=False) + ) + + if not all( + isinstance(v, numbers.Number) or type(v) is np.ndarray for v in values + ): + msg = "All values should be a scalar number or a numpy array." + raise TypeError(msg) + + high_level: list[HighLevelObject] = values_to_high_level_objects( + *values, low_level_wcs=self + ) # type: ignore[no-untyped-call] + if len(high_level) == 1: + return high_level[0] + + return tuple(high_level) + + def from_high_level_coordinates( + self, *high_level_coords: HighLevelObject + ) -> LowLevelArrays: + """ + Convert high level coordinate objects to "values" as described by this frame. + + "values" are the coordinates in array or scalar form, and high level + objects are things such as ``SkyCoord`` or ``Quantity``. See + :ref:`wcsapi` for details. + + Parameters + ---------- + high_level_coordinates + One (or more) high level object describing the coordinate. + + Returns + ------- + values : `numbers.Number` or `numpy.ndarray` + ``naxis`` number of coordinates as scalars or arrays. + """ + values: list[LowLevelValue] = high_level_objects_to_values( + *high_level_coords, low_level_wcs=self + ) # type: ignore[no-untyped-call] + if len(values) == 1: + return values[0] + return tuple(values) diff --git a/gwcs/coordinate_frames/_empty.py b/gwcs/coordinate_frames/_empty.py new file mode 100644 index 00000000..9e5e2101 --- /dev/null +++ b/gwcs/coordinate_frames/_empty.py @@ -0,0 +1,76 @@ +from astropy import units as u +from astropy.coordinates import BaseCoordinateFrame as _BaseCoordinateFrame + +from gwcs._typing import AxisPhysicalTypes +from gwcs.api import WorldAxisClasses, WorldAxisComponents + +from ._axis import AxesType +from ._core import CoordinateFrame + +__all__ = ["EmptyFrame"] + + +class EmptyFrame(CoordinateFrame): + """ + Represents a "default" detector frame. This is for use as the default value + for input frame by the WCS object. + """ + + def __init__(self, name: str | None = None) -> None: + self._name = "detector" if name is None else name + + def __repr__(self) -> str: + return f'<{type(self).__name__}(name="{self.name}")>' + + def __str__(self) -> str: + return self._name + + @property + def name(self) -> str: + """A custom name of this frame.""" + return self._name + + @name.setter + def name(self, val: str) -> None: + """A custom name of this frame.""" + self._name = val + + def _raise_error(self) -> None: + msg = "EmptyFrame does not have any information" + raise NotImplementedError(msg) + + @property + def naxes(self) -> int: # type: ignore[return] + self._raise_error() + + @property + def unit(self) -> tuple[u.Unit, ...]: # type: ignore[return] + self._raise_error() + + @property + def axes_names(self) -> tuple[str, ...]: # type: ignore[return] + self._raise_error() + + @property + def axes_order(self) -> tuple[int, ...]: # type: ignore[return] + self._raise_error() + + @property + def reference_frame(self) -> _BaseCoordinateFrame | None: # type: ignore[return] + self._raise_error() + + @property + def axes_type(self) -> AxesType: # type: ignore[return] + self._raise_error() + + @property + def axis_physical_types(self) -> AxisPhysicalTypes: # type: ignore[return] + self._raise_error() + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: # type: ignore[return] + self._raise_error() + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: # type: ignore[return] + self._raise_error() diff --git a/gwcs/coordinate_frames/_frame.py b/gwcs/coordinate_frames/_frame.py new file mode 100644 index 00000000..98fdbdc0 --- /dev/null +++ b/gwcs/coordinate_frames/_frame.py @@ -0,0 +1,57 @@ +from astropy import units as u + +from gwcs._typing import AxisPhysicalTypes + +from ._axis import AxesType, AxisType +from ._core import CoordinateFrame +from ._properties import FrameProperties + + +class Frame2D(CoordinateFrame): + """ + A 2D coordinate frame. + + Parameters + ---------- + axes_order + A dimension in the input data that corresponds to this axis. + unit + Unit for each axis. + axes_name + Names of the axes in this frame. + name + Name of this frame. + """ + + def __init__( + self, + axes_order: tuple[int, ...] = (0, 1), + # Astropy dynamically builds these types at runtime, so MyPy can't find them + unit: tuple[u.Unit, ...] = (u.pix, u.pix), # type: ignore[attr-defined] + axes_names: tuple[str, ...] = ("x", "y"), + name: str | None = None, + axes_type: AxesType | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + if axes_type is None: + axes_type = (AxisType.SPATIAL, AxisType.SPATIAL) + + super().__init__( + naxes=2, + axes_type=axes_type, + axes_order=axes_order, + name=name, + axes_names=axes_names, + unit=unit, + axis_physical_types=axis_physical_types, + ) + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + if all(properties.axes_names): + ph_type = properties.axes_names + else: + ph_type = properties.axes_type + + return tuple(f"custom:{t}" for t in ph_type) diff --git a/gwcs/coordinate_frames/_properties.py b/gwcs/coordinate_frames/_properties.py new file mode 100644 index 00000000..9c12cdde --- /dev/null +++ b/gwcs/coordinate_frames/_properties.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from collections.abc import Callable + +from astropy.units import Unit, dimensionless_unscaled +from astropy.utils.misc import isiterable +from astropy.wcs.wcsapi.low_level_api import VALID_UCDS, validate_physical_types + +from gwcs._typing import AxisPhysicalType, AxisPhysicalTypes + +from ._axis import AxesType + +__all__ = ["FrameProperties"] + + +class FrameProperties: + def __init__( + self, + naxes: int, + axes_type: AxesType, + unit: tuple[Unit, ...] | None = None, + axes_names: str | tuple[str, ...] | None = None, + axis_physical_types: AxisPhysicalType | AxisPhysicalTypes | None = None, + default_axis_physical_types: Callable[[FrameProperties], AxisPhysicalTypes] + | None = None, + ) -> None: + self.naxes = naxes + + self.axes_type = ( + (axes_type,) if isinstance(axes_type, str) else tuple(axes_type) + ) + + if len(self.axes_type) != naxes: + msg = "Length of axes_type does not match number of axes." + raise ValueError(msg) + + if unit is None: + self.unit = tuple(dimensionless_unscaled for _ in range(naxes)) + else: + unit_ = tuple(unit) if isiterable(unit) else (unit,) # type: ignore[no-untyped-call] + if len(unit_) != naxes: + msg = "Number of units does not match number of axes." + raise ValueError(msg) + self.unit = tuple(Unit(au) for au in unit_) # type: ignore[no-untyped-call, misc] + + if axes_names is None: + self.axes_names = tuple([""] * naxes) + else: + self.axes_names = ( + (axes_names,) if isinstance(axes_names, str) else tuple(axes_names) + ) + if len(self.axes_names) != naxes: + msg = "Number of axes names does not match number of axes." + raise ValueError(msg) + + if axis_physical_types is None: + if default_axis_physical_types is None: + default_axis_physical_types = self._default_axis_physical_types + + self.axis_physical_types: AxisPhysicalTypes = default_axis_physical_types( + self + ) + else: + self.axis_physical_types = ( + (axis_physical_types,) + if isinstance(axis_physical_types, str) + else tuple(axis_physical_types) + ) + + if len(self.axis_physical_types) != naxes: + msg = f'"axis_physical_types" must be of length {naxes}' + raise ValueError(msg) + + self.axis_physical_types = tuple( + f"custom:{axt}" + if axt not in VALID_UCDS and not axt.startswith("custom:") + else axt + for axt in self.axis_physical_types + ) + validate_physical_types(self.axis_physical_types) # type: ignore[no-untyped-call] + + @staticmethod + def _default_axis_physical_types(properties: FrameProperties) -> AxisPhysicalTypes: + """ + The default physical types to use for this frame if none are specified + by the user. + """ + return tuple(f"custom:{t}" for t in properties.axes_type) diff --git a/gwcs/coordinate_frames/_spectral.py b/gwcs/coordinate_frames/_spectral.py new file mode 100644 index 00000000..8dcce4c7 --- /dev/null +++ b/gwcs/coordinate_frames/_spectral.py @@ -0,0 +1,78 @@ +from astropy import units as u +from astropy.coordinates import BaseCoordinateFrame, SpectralCoord + +from gwcs._typing import AxisPhysicalTypes +from gwcs.api import ( + WorldAxisClass, + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, +) + +from ._axis import AxisType +from ._core import CoordinateFrame +from ._properties import FrameProperties + +__all__ = ["SpectralFrame"] + + +class SpectralFrame(CoordinateFrame): + """ + Represents Spectral Frame + + Parameters + ---------- + axes_order + A dimension in the input data that corresponds to this axis. + reference_frame : astropy.coordinates.builtin_frames + Reference frame (usually used with output_frame to convert to world + coordinate objects). + unit + Spectral unit. + axes_names + Spectral axis name. + name + Name for this frame. + + """ + + def __init__( + self, + axes_order: tuple[int, ...] = (0,), + reference_frame: BaseCoordinateFrame | None = None, + unit: tuple[u.Unit, ...] | None = None, + axes_names: tuple[str, ...] | None = None, + name: str | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + super().__init__( + naxes=1, + axes_type=AxisType.SPECTRAL, + axes_order=axes_order, + axes_names=axes_names, + reference_frame=reference_frame, + unit=unit, + name=name, + axis_physical_types=axis_physical_types, + ) + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + if properties.unit[0].physical_type == "frequency": + return ("em.freq",) + if properties.unit[0].physical_type == "length": + return ("em.wl",) + if properties.unit[0].physical_type == "energy": + return ("em.energy",) + if properties.unit[0].physical_type == "speed": + return ("spect.dopplerVeloc",) + return (f"custom:{properties.unit[0].physical_type}",) + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return {"spectral": WorldAxisClass(SpectralCoord, (), {"unit": self.unit[0]})} + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: + return [WorldAxisComponent("spectral", 0, lambda sc: sc.to_value(self.unit[0]))] diff --git a/gwcs/coordinate_frames/_stokes.py b/gwcs/coordinate_frames/_stokes.py new file mode 100644 index 00000000..20aa9f78 --- /dev/null +++ b/gwcs/coordinate_frames/_stokes.py @@ -0,0 +1,65 @@ +from astropy import units as u +from astropy.coordinates import StokesCoord + +from gwcs._typing import AxisPhysicalTypes +from gwcs.api import ( + WorldAxisClass, + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, +) + +from ._axis import AxisType +from ._core import CoordinateFrame +from ._properties import FrameProperties + +__all__ = ["StokesFrame"] + + +class StokesFrame(CoordinateFrame): + """ + A coordinate frame for representing Stokes polarisation states. + + Parameters + ---------- + name + Name of this frame. + axes_order + A dimension in the data that corresponds to this axis. + """ + + def __init__( + self, + axes_order: tuple[int, ...] = (0,), + axes_names: tuple[str, ...] = ("stokes",), + name: str | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + super().__init__( + 1, + (AxisType.STOKES,), + axes_order, + name=name, + axes_names=axes_names, + unit=(u.one,), # type: ignore[arg-type] + axis_physical_types=axis_physical_types, + ) + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + return ("phys.polarization.stokes",) + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return { + "stokes": WorldAxisClass( + StokesCoord, + (), + {}, + ) + } + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: + return [WorldAxisComponent("stokes", 0, "value")] diff --git a/gwcs/coordinate_frames/_temporal.py b/gwcs/coordinate_frames/_temporal.py new file mode 100644 index 00000000..3b527151 --- /dev/null +++ b/gwcs/coordinate_frames/_temporal.py @@ -0,0 +1,117 @@ +import contextlib +from typing import Any, cast + +import numpy as np +from astropy import units as u +from astropy.time import Time, TimeDelta + +from gwcs._typing import AxisPhysicalTypes +from gwcs.api import ( + WorldAxisClasses, + WorldAxisComponent, + WorldAxisComponents, + WorldAxisConverterClass, +) + +from ._axis import AxisType +from ._core import CoordinateFrame +from ._properties import FrameProperties + +__all__ = ["TemporalFrame"] + + +class TemporalFrame(CoordinateFrame): + """ + A coordinate frame for time axes. + + Parameters + ---------- + reference_frame + A Time object which holds the time scale and format. + If data is provided, it is the time zero point. + To not set a zero point for the frame initialize ``reference_frame`` + with an empty list. + unit + Time unit. + axes_names + Time axis name. + axes_order + A dimension in the data that corresponds to this axis. + name + Name for this frame. + """ + + def __init__( + self, + reference_frame: Time, + # Astropy dynamically builds these types at runtime, so MyPy can't find them + unit: tuple[u.Unit, ...] = (u.s,), # type: ignore[attr-defined] + axes_order: tuple[int, ...] = (0,), + axes_names: tuple[str, ...] | None = None, + name: str | None = None, + axis_physical_types: AxisPhysicalTypes | None = None, + ) -> None: + _axes_names = ( + ( + f"{reference_frame.format}({reference_frame.scale}; " + f"{reference_frame.location}", + ) + if axes_names is None + else axes_names + ) + + super().__init__( + naxes=1, + axes_type=AxisType.TIME, + axes_order=axes_order, + axes_names=_axes_names, + reference_frame=reference_frame, + unit=unit, + name=name, + axis_physical_types=axis_physical_types, + ) + self._attrs = {} + for a in self.reference_frame.info._represent_as_dict_extra_attrs: + with contextlib.suppress(AttributeError): + self._attrs[a] = getattr(self.reference_frame, a) + + @property + def reference_frame(self) -> Time: + return cast(Time, self._reference_frame) + + def _default_axis_physical_types( + self, properties: FrameProperties + ) -> AxisPhysicalTypes: + return ("time",) + + def _convert_to_time(self, dt: Any, *, unit: u.Unit, **kwargs: Any) -> Time: + if (not isinstance(dt, TimeDelta) and isinstance(dt, Time)) or isinstance( + self.reference_frame.value, np.ndarray + ): + return Time(dt, **kwargs) # type: ignore[no-untyped-call] + + if not hasattr(dt, "unit"): + dt = dt * unit + + return cast(Time, self.reference_frame + dt) + + @property + def world_axis_object_classes(self) -> WorldAxisClasses: + return { + "temporal": WorldAxisConverterClass( + Time, + (), + {"unit": self.unit[0], **self._attrs}, + self._convert_to_time, + ) + } + + @property + def _native_world_axis_object_components(self) -> WorldAxisComponents: + if isinstance(self.reference_frame.value, np.ndarray): + return [WorldAxisComponent("temporal", 0, "value")] + + def offset_from_time_and_reference(time: Time) -> Any: + return (time - self.reference_frame).sec + + return [WorldAxisComponent("temporal", 0, offset_from_time_and_reference)] diff --git a/gwcs/coordinate_frames/_utils.py b/gwcs/coordinate_frames/_utils.py new file mode 100644 index 00000000..2b9c71c1 --- /dev/null +++ b/gwcs/coordinate_frames/_utils.py @@ -0,0 +1,60 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst + +import logging + +from astropy.wcs.wcsapi.fitswcs import CTYPE_TO_UCD1 + +__all__ = ["get_ctype_from_ucd"] + + +def _ucd1_to_ctype_name_mapping( + ctype_to_ucd: dict[str, str], allowed_ucd_duplicates: dict[str, str] +) -> dict[str, str]: + inv_map = {} + new_ucd = set() + + for kwd, ucd in ctype_to_ucd.items(): + if ucd in inv_map: + if ucd not in allowed_ucd_duplicates: + new_ucd.add(ucd) + continue + inv_map[ucd] = allowed_ucd_duplicates.get(ucd, kwd) + + if new_ucd: + logging.warning( + "Found unsupported duplicate physical type in 'astropy' mapping to CTYPE.\n" + "Update 'gwcs' to the latest version or notify 'gwcs' developer.\n" + "Duplicate physical types will be mapped to the following CTYPEs:\n" + + "\n".join([f"{ucd!r:s} --> {inv_map[ucd]!r:s}" for ucd in new_ucd]) + ) + + return inv_map + + +# List below allowed physical type duplicates and a corresponding CTYPE +# to which all duplicates will be mapped to: +_ALLOWED_UCD_DUPLICATES = { + "time": "TIME", + "em.wl": "WAVE", +} + +UCD1_TO_CTYPE = _ucd1_to_ctype_name_mapping( + ctype_to_ucd=CTYPE_TO_UCD1, allowed_ucd_duplicates=_ALLOWED_UCD_DUPLICATES +) + + +def get_ctype_from_ucd(ucd: str) -> str: + """ + Return the FITS ``CTYPE`` corresponding to a UCD1 value. + + Parameters + ---------- + ucd : str + UCD string, for example one of ```WCS.world_axis_physical_types``. + + Returns + ------- + CTYPE : str + The corresponding FITS ``CTYPE`` value or an empty string. + """ + return UCD1_TO_CTYPE.get(ucd, "") diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index fd6cf834..96948638 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -16,6 +16,10 @@ from gwcs import WCS from gwcs import coordinate_frames as cf +from gwcs.coordinate_frames._utils import ( + _ALLOWED_UCD_DUPLICATES, + _ucd1_to_ctype_name_mapping, +) astropy_version = astropy.__version__ @@ -102,9 +106,9 @@ def coordinates(*inputs, frame): def coordinate_to_quantity(*inputs, frame): results = frame.from_high_level_coordinates(*inputs) - if not isinstance(results, list): - results = [results] - return [r << unit for r, unit in zip(results, frame.unit, strict=False)] + if not isinstance(results, tuple): + results = (results,) + return tuple(r << unit for r, unit in zip(results, frame.unit, strict=False)) @pytest.mark.parametrize("inputs", inputs2) @@ -520,8 +524,8 @@ def test_ucd1_to_ctype_not_out_of_sync(caplog): dictionary with new types defined in ``astropy``'s ``CTYPE_TO_UCD1``. """ - cf._ucd1_to_ctype_name_mapping( - ctype_to_ucd=CTYPE_TO_UCD1, allowed_ucd_duplicates=cf._ALLOWED_UCD_DUPLICATES + _ucd1_to_ctype_name_mapping( + ctype_to_ucd=CTYPE_TO_UCD1, allowed_ucd_duplicates=_ALLOWED_UCD_DUPLICATES ) assert len(caplog.record_tuples) == 0 @@ -536,8 +540,8 @@ def test_ucd1_to_ctype(caplog): ctype_to_ucd = dict(**CTYPE_TO_UCD1, **new_ctype_to_ucd) - inv_map = cf._ucd1_to_ctype_name_mapping( - ctype_to_ucd=ctype_to_ucd, allowed_ucd_duplicates=cf._ALLOWED_UCD_DUPLICATES + inv_map = _ucd1_to_ctype_name_mapping( + ctype_to_ucd=ctype_to_ucd, allowed_ucd_duplicates=_ALLOWED_UCD_DUPLICATES ) assert caplog.record_tuples[-1][1] == logging.WARNING @@ -545,7 +549,7 @@ def test_ucd1_to_ctype(caplog): "Found unsupported duplicate physical type" ) - for k, v in cf._ALLOWED_UCD_DUPLICATES.items(): + for k, v in _ALLOWED_UCD_DUPLICATES.items(): assert inv_map.get(k, "") == v for k, v in inv_map.items(): diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index 37918bbb..53ba5b49 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -1813,3 +1813,58 @@ def test_direct_numerical_inverse(gwcs_romanisim): out = gwcs_romanisim.numerical_inverse(*ra_dec) assert_allclose(xy, out) + + +def test_array_high_level_output(): + """ + Test that we don't loose array values when requesting a high-level output + from a WCS object. + """ + input_frame = cf.CoordinateFrame( + naxes=1, + axes_type=("SPATIAL",), + axes_order=(0,), + name="pixels", + unit=(u.pix,), + axes_names=("x",), + ) + output_frame = cf.SpectralFrame(unit=(u.nm,), axes_names=("lambda",)) + wave_model = models.Scale(0.1) | models.Shift(500) + gwcs = wcs.WCS([(input_frame, wave_model), (output_frame, None)]) + assert ( + gwcs(np.array([0, 1, 2]), with_units=True) + == coord.SpectralCoord([500, 500.1, 500.2] * u.nm) + ).all() + + +def test_parameterless_transform(): + """ + Test that a transform with no parameters correctly handles units. + -> The wcs does not introduce units when evaluating the forward or backward + transform for models with no parameters + Regression test for #558 + """ + + in_frame = cf.Frame2D(name="in_frame") + out_frame = cf.Frame2D(name="out_frame") + + gwcs = wcs.WCS( + [ + (in_frame, models.Identity(2)), + (out_frame, None), + ] + ) + + # The expectation for this wcs is that: + # - gwcs(1, 1) has no units + # (__call__ apparently is supposed to pass units through?) + # - gwcs(1*u.pix, 1*u.pix) has units + # - gwcs.invert(1, 1) has no units + # - gwcs.invert(1*u.pix, 1*u.pix) has no units + + # No units introduced by the forward transform + assert gwcs(1, 1) == (1, 1) + assert gwcs(1 * u.pix, 1 * u.pix) == (1 * u.pix, 1 * u.pix) + + assert gwcs.invert(1, 1) == (1, 1) + assert gwcs.invert(1 * u.pix, 1 * u.pix) == (1, 1) diff --git a/gwcs/wcs/_wcs.py b/gwcs/wcs/_wcs.py index 51a5790f..bf3fc067 100644 --- a/gwcs/wcs/_wcs.py +++ b/gwcs/wcs/_wcs.py @@ -6,7 +6,6 @@ import astropy.units as u import numpy as np -from astropy import utils as astutil from astropy.io import fits from astropy.modeling import fix_inputs, projections from astropy.modeling.bounding_box import ModelBoundingBox as Bbox @@ -117,13 +116,10 @@ def __init__( self._pixel_shape = None def _add_units_input( - self, arrays: list[np.ndarray], frame: CoordinateFrame | None + self, arrays: np.ndarray | float, frame: CoordinateFrame | None ) -> tuple[u.Quantity, ...]: if frame is not None: - return tuple( - u.Quantity(array, unit) - for array, unit in zip(arrays, frame.unit, strict=False) - ) + return frame.add_units(arrays) return arrays @@ -131,10 +127,7 @@ def _remove_units_input( self, arrays: list[u.Quantity], frame: CoordinateFrame | None ) -> tuple[np.ndarray, ...]: if frame is not None: - return tuple( - array.to_value(unit) if isinstance(array, u.Quantity) else array - for array, unit in zip(arrays, frame.unit, strict=False) - ) + return frame.remove_units(arrays) return arrays @@ -166,10 +159,7 @@ def __call__( results = self._call_forward( *args, with_bounding_box=with_bounding_box, fill_value=fill_value, **kwargs ) - if with_units: - if not astutil.isiterable(results): - results = (results,) # values are always expected to be arrays or scalars not quantities results = self._remove_units_input(results, self.output_frame) high_level = values_to_high_level_objects(*results, low_level_wcs=self) @@ -178,6 +168,85 @@ def __call__( return high_level return results + def _evaluate_transform( + self, + transform, + from_frame, + to_frame, + *args, + with_bounding_box: bool = True, + fill_value: float | np.number = np.nan, + **kwargs, + ): + """ + Introduces or removes units from the arguments as need so that the transform + can be successfully evaluated. + + Notes + ----- + Much of the logic in this method is due to the unfortunate fact that the + `uses_quantity` property for models is not reliable for determining if one + must pass quantities or not. It instead tells you: + 1. If it has any parameter that is a quantity + 2. It defaults to true for parameterless models. + + This is problematic because its entirely possible to construct a model with + a parameter that is a quantity but the model itself either doesn't require + them or in fact cannot use them. This is a very rare case but it could happen. + Currently, this case is not handled, but it is worth noting in case it comes up + + The more problematic case is for parameterless models. `uses_quantity` assumes + that if there are no parameters, then the model is agnostic to quantity inputs. + This is an incorrect assumption, even with in `astropy.modeling`'s built in + models. The `Tabular1D` model for example has no "parameters" but it can + require quantities if its "points" construction input is a quantity. This + is the main case for the try/except block in this method. + + Properly dealing with this will require upstream work in `astropy.modeling` + which is outside the scope of what GWCS can control. + + to_frame is included as we really ought to be stripping the result of units + but we currently are not. API refactor should include addressing this. + """ + + # Validate that the input type matches what the transform expects + input_is_quantity = any(isinstance(a, u.Quantity) for a in args) + + def _transform(*args): + """Wrap the transform evaluation""" + + return transform( + *args, + with_bounding_box=with_bounding_box, + fill_value=fill_value, + **kwargs, + ) + + # Models with no parameters claim they use quantities but this may incorrectly + # introduce units so we don't at first + if ( + not input_is_quantity + and transform.uses_quantity + and transform.parameters.size + ): + args = self._add_units_input(args, from_frame) + if not transform.uses_quantity and input_is_quantity: + args = self._remove_units_input(args, from_frame) + + try: + return _transform(*args) + except u.UnitsError: + # In this case we are handling parameterless models that require units + # to function correctly. + if ( + not input_is_quantity + and transform.uses_quantity + and not transform.parameters.size + ): + return _transform(*self._add_units_input(args, from_frame)) + + raise + def _call_forward( self, *args, @@ -203,15 +272,14 @@ def _call_forward( msg = "WCS.forward_transform is not implemented." raise NotImplementedError(msg) - # Validate that the input type matches what the transform expects - input_is_quantity = any(isinstance(a, u.Quantity) for a in args) - if not input_is_quantity and transform.uses_quantity: - args = self._add_units_input(args, from_frame) - if not transform.uses_quantity and input_is_quantity: - args = self._remove_units_input(args, from_frame) - - return transform( - *args, with_bounding_box=with_bounding_box, fill_value=fill_value, **kwargs + return self._evaluate_transform( + transform, + from_frame, + to_frame, + *args, + with_bounding_box=with_bounding_box, + fill_value=fill_value, + **kwargs, ) def in_image(self, *args, **kwargs): @@ -334,16 +402,12 @@ def _call_backward( args = self.outside_footprint(args) if transform is not None: - # Validate that the input type matches what the transform expects - input_is_quantity = any(isinstance(a, u.Quantity) for a in args) - if not input_is_quantity and transform.uses_quantity: - args = self._add_units_input(args, self.output_frame) - if not transform.uses_quantity and input_is_quantity: - args = self._remove_units_input(args, self.output_frame) - # remove iterative inverse-specific keyword arguments: akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS} - result = transform( + result = self._evaluate_transform( + transform, + self.output_frame, + self.input_frame, *args, with_bounding_box=with_bounding_box, fill_value=fill_value, @@ -403,7 +467,9 @@ def outside_footprint(self, world_arrays): max_ax = axis_range[~m].min() outside = (coord > min_ax) & (coord < max_ax) else: - coord_ = self._remove_units_input([coord], self.output_frame)[0] + coord_ = self._remove_quantity_output( + world_arrays, self.output_frame + )[idim] outside = (coord_ < min_ax) | (coord_ > max_ax) if np.any(outside): if np.isscalar(coord): @@ -1202,7 +1268,6 @@ def footprint(self, bounding_box=None, center=False, axis_type="all"): """ def _order_clockwise(v): - v = [self._remove_units_input(vv, self.input_frame) for vv in v] return np.asarray( [ [v[0][0], v[1][0]], @@ -1226,7 +1291,9 @@ def _order_clockwise(v): bb = np.asarray([b.value for b in bb]) * bb[0].unit vertices = (bb,) elif all_spatial: - vertices = _order_clockwise(bb) + vertices = _order_clockwise( + [self._remove_units_input(b, self.input_frame) for b in bb] + ) else: vertices = np.array(list(itertools.product(*bb))).T diff --git a/pyproject.toml b/pyproject.toml index 961b0bed..8c1008d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "gwcs" description = "Generalized World Coordinate System" -requires-python = ">=3.10" +requires-python = ">=3.11" authors = [ { name = "gwcs developers", email = "help@stsci.edu" }, ] @@ -44,7 +44,6 @@ docs = [ "sphinx-copybutton", "pydata-sphinx-theme", "sphinx-asdf", - "tomli; python_version <'3.11'", ] test = [ "ci-watson>=0.3.0", @@ -175,3 +174,60 @@ ignore = [ "gwcs/converters/tests/*" = ["S101"] "docs/conf.py" = ["INP001", "ERA001"] "convert_schemas.py" = ["PTH"] + + +[tool.mypy] +files = "gwcs" +python_version = "3.11" +strict = true +enable_error_code = [ + "ignore-without-code", + "redundant-expr", + "truthy-bool", +] +exclude = [ + 'gwcs/tests/$', + 'gwcs/converters/tests/$', + 'gwcs/wcs/$', + 'gwcs/api.py', + 'gwcs/coordinate_frames.py', + 'gwcs/examples.py', + 'gwcs/extension.py', + 'gwcs/geometry.py', + 'gwcs/region.py', + 'gwcs/selector.py', + 'gwcs/spectroscopy.py', + 'gwcs/utils.py', + 'gwcs/wcstools.py', +] +warn_unreachable = true + +[[tool.mypy.overrides]] +module = [ + "astropy.coordinates.*", + "astropy.modeling.*", + # "astropy.table.*", + "astropy.time.*", + "astropy.units.*", + "astropy.utils.*", + "astropy.wcs.*", +] +follow_untyped_imports = true + +# For gradual adoption +[[tool.mypy.overrides]] +module = [ + "gwcs.tests.*", + "gwcs.converters.*", + "gwcs.wcs.*", + "gwcs.api.*", + "gwcs.examples.*", + "gwcs.extension.*", + "gwcs.geometry.*", + "gwcs.region.*", + "gwcs.selector.*", + "gwcs.spectroscopy.*", + "gwcs.utils.*", + "gwcs.wcstools.*", +] +ignore_errors = true