Skip to content

Commit

Permalink
add api tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nden committed Sep 29, 2018
1 parent ea72474 commit 522e973
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 22 deletions.
20 changes: 7 additions & 13 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def world_axis_units(self):
specification document, units that do not follow this standard are still
allowed, but just not recommended).
"""
return [unit.to_string(format='vounit') for unit in self.output_frame.unit]
return tuple(unit.to_string(format='vounit') for unit in self.output_frame.unit)

def pixel_to_world_values(self, *pixel_arrays):
"""
Expand Down Expand Up @@ -111,7 +111,7 @@ def world_to_array_index_values(self, *world_arrays):
`~BaseLowLevelWCS.pixel_to_world_values`). The indices should be
returned as rounded integers.
"""
result = self.invert(*world_arrays[::-1], with_units=False)
result = self.invert(*world_arrays, with_units=False)[::-1]
return result # astype(int)

@property
Expand Down Expand Up @@ -157,7 +157,7 @@ def axis_correlation_matrix(self):
would be `True` and all other entries `False`.
"""

return separable.separability_matrix(self.forward_transform)[1]
return separable.separability_matrix(self.forward_transform)

def serialized_classes(self):
"""
Expand Down Expand Up @@ -201,25 +201,19 @@ def array_index_to_world(self, *index_arrays):
Convert array indices to world coordinates (represented by Astropy
objects).
"""
return self(*index_arrays[::-1], with_units=True)
return self(*(index_arrays[::-1]), with_units=True)

def world_to_pixel(self, *world_objects):
"""
Convert world coordinates to pixel values.
"""
args = self.output_frame.coordinates(*world_objects)

if len(world_objects) == 1:
args = [args]
if not self.forward_transform.uses_quantity:
args = self.output_frame.coordinate_to_quantity(*args)
args = utils.get_values(self.output_frame.unit, *args)
return self.invert(*args)
return self.invert(*world_objects)

def world_to_array_index(self, *world_objects):
"""
Convert world coordinates (represented by Astropy objects) to array
indices.
"""
return self.invert(*world_objects[::-1], with_units=True)
result = self.invert(*world_objects, with_units=True)[::-1]
return tuple([utils._toindex(r) for r in result])

22 changes: 13 additions & 9 deletions gwcs/coordinate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from astropy import units as u
from astropy import utils as astutil
from astropy import coordinates as coord
from astropy.wcs.wcsapi.low_level_api import validate_physical_types
from astropy.wcs.wcsapi.low_level_api import (validate_physical_types,
VALID_UCDS)


__all__ = ['Frame2D', 'CelestialFrame', 'SpectralFrame', 'CompositeFrame',
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(self, naxes, axes_type, axes_order, reference_frame=None,
raise ValueError("Number of axis_physical_type does not match number of axes.")
else:
axis_physical_type = tuple(["custom:{}".format(t) for t in self._axes_type])
axis_physical_type = tuple("custom:{}".format(pht) for pht in axis_physical_type if pht not in VALID_UCDS)
self._axis_physical_type = axis_physical_type
if name is None:
self._name = self.__class__.__name__
Expand Down Expand Up @@ -182,9 +184,6 @@ def coordinate_to_quantity(self, *coords):

@property
def axis_physical_type(self):
#if self.naxes == 1:
# return self._axis_physical_type[0]
#else:
return self._axis_physical_type

def set_physical_type(self, type=None):
Expand Down Expand Up @@ -291,14 +290,19 @@ def set_physical_type(self, type):
"custom:pos.{}".format(type[1]))
else:
raise ValueError("Expected physical_type to be a tuple of length 2.")
if isinstance(self.reference_frame, coord.builtin_frames.BaseRADecFrame):
if isinstance(self.reference_frame, (coord.Galactic,
coord.Galactocentric)):
ph_type = "pos.galactic.lon", "pos.galactic.lat"
elif isinstance(self.reference_frame, (coord.GeocentricTrueEcliptic,
coord.GCRS,
coord.PrecessedGeocentric,
coord.ITRS)):
ph_type = "pos.bodyrc.lon", "pos.bodyrc.lat"
elif isinstance(self.reference_frame, coord.builtin_frames.BaseRADecFrame):
ph_type = "pos.eq.ra", "pos.eq.dec"
elif isinstance(self.reference_frame, coord.builtin_frames.BaseEclipticFrame):
ph_type = "pos.ecliptic.lon", "pos.ecliptic.lat"
elif isinstance(self.reference_frame, coord.Galactic):
ph_type = "pos.galactic.lon", "pos.galactic.lat"
elif isinstance(self.reference_frame, coord.GeocentricTrueEcliptic):
ph_type = "pos.bodyrc.lon", "pos.bodyrc.lat"

else:
ph_type = ("custom:{}".format(self.axes_type[0]),
"custom:{}".format(self.axes_type[1]))
Expand Down
185 changes: 185 additions & 0 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Tests the API defined in astropy APE 14 (https://doi.org/10.5281/zenodo.1188875).
"""
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal

from astropy import coordinates as coord
from astropy.modeling import models
from astropy import units as u

from .. import wcs
from .. import coordinate_frames as cf

import pytest

# frames
sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), axes_order=(0, 1))
detector = cf.Frame2D(name='detector', axes_order=(0, 1))

spec1 = cf.SpectralFrame(name='freq', unit=[u.Hz, ], axes_order=(2, ))
spec2 = cf.SpectralFrame(name='wave', unit=[u.m, ], axes_order=(2, ), axes_names=('lambda', ))

comp1 = cf.CompositeFrame([sky_frame, spec1])

# transforms

m1 = models.Shift(1) & models.Shift(2)
m2 = models.Scale(2)
m = m1 & m2

pipe = [(detector, m1),
(sky_frame, None)
]

example_wcs = wcs.WCS(pipe)

def create_example_wcs():
example_wcs = [wcs.WCS([(detector, m1),
(sky_frame, None)]),
wcs.WCS([(detector, m2),
(spec1, None)]),
wcs.WCS([(detector, m),
(comp1, None)])
]

pixel_world_ndim = [(2, 2), (2, 1), (2, 3)]
physical_types = [("pos.eq.ra", "pos.eq.dec"), ("em.freq",), ("pos.eq.ra", "pos.eq.dec", "em.freq")]
world_units = [("deg", "deg"), ("Hz",), ("deg", "deg", "Hz")]

return example_wcs, pixel_world_ndim, physical_types, world_units

# x, y inputs - scalar and array
x, y = 1, 2
xarr, yarr = np.ones((3, 4)), np.ones((3, 4)) + 1

# ra, dec inputs - scalar, arrays and SkyCoord objects
ra, dec = 2, 4
sky = coord.SkyCoord(ra * u.deg, dec*u.deg, frame = sky_frame.reference_frame)
raarr = np.ones((3, 4)) * ra
decarr = np.ones((3, 4)) * dec
skyarr = coord.SkyCoord(raarr * u.deg, decarr*u.deg, frame = sky_frame.reference_frame)

ex_wcs, dims, physical_types, world_units = create_example_wcs()

@pytest.mark.parametrize(("wcsobj", "ndims"), zip(ex_wcs, dims))
def test_pixel_n_dim(wcsobj, ndims):
assert wcsobj.pixel_n_dim == ndims[0]


@pytest.mark.parametrize(("wcsobj", "ndims"), zip(ex_wcs, dims))
def test_world_n_dim(wcsobj, ndims):
assert wcsobj.world_n_dim == ndims[1]


@pytest.mark.parametrize(("wcsobj", "physical_types"), zip(ex_wcs, physical_types))
def test_world_axis_physical_types(wcsobj, physical_types):
assert wcsobj.world_axis_physical_types == physical_types


@pytest.mark.parametrize(("wcsobj", "world_units"), zip(ex_wcs, world_units))
def test_world_axis_units(wcsobj, world_units):
assert wcsobj.world_axis_units == world_units


@pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr)))
def test_pixel_to_world_values(x, y):
wcsobj = example_wcs
assert_allclose(wcsobj.pixel_to_world_values(x, y), wcsobj(x, y, with_units=False))


@pytest.mark.parametrize(("x", "y"), zip((x, xarr), (y, yarr)))
def test_array_index_to_world_values(x, y):
wcsobj = example_wcs
assert_allclose(wcsobj.array_index_to_world_values(x, y), wcsobj(y, x, with_units=False))


@pytest.mark.parametrize(("sky", "ra", "dec"), zip((sky, skyarr), (ra, raarr), (dec, decarr)))
def test_world_to_pixel_values(sky, ra, dec):
wcsobj = example_wcs
assert_allclose(wcsobj.world_to_pixel_values(sky), wcsobj.invert(ra, dec, with_units=False))


@pytest.mark.parametrize(("sky", "ra", "dec"), zip((sky, skyarr), (ra, raarr), (dec, decarr)))
def test_world_to_array_index_values(sky, ra, dec):
wcsobj = example_wcs
assert_allclose(wcsobj.world_to_array_index_values(sky),
wcsobj.invert(ra, dec, with_units=False)[::-1])


def test_world_axis_object_components():
wcsobj = example_wcs
with pytest.raises(NotImplementedError):
wcsobj.world_axis_object_components()


def test_world_axis_object_classes():
wcsobj = example_wcs
with pytest.raises(NotImplementedError):
wcsobj.world_axis_object_classes()


def test_array_shape():
wcsobj = example_wcs
assert wcsobj.array_shape is None

wcsobj.array_shape = (2040, 1020)
assert wcsobj.array_shape is (2040, 1020)


def test_pixel_bounds():
wcsobj = example_wcs
assert wcsobj.pixel_bounds is None

wcsobj.bounding_box = ((-0.5, 2039.5), (-0.5, 1019.5))
assert_array_equal(wcsobj.pixel_bounds, wcsobj.bounding_box)


def test_axis_correlation_matrix():
wcsobj = example_wcs
assert_array_equal(wcsobj.axis_correlation_matrix, np.identity(2))


def test_serialized_classes():
wcsobj = example_wcs
assert wcsobj.serialized_classes() == False


def test_low_level_wcs():
wcsobj = example_wcs
assert id(wcsobj.low_level_wcs()) == id(wcsobj)


def test_pixel_to_world():
wcsobj = example_wcs
comp = wcsobj(x, y, with_units=True)
comp = wcsobj.output_frame.coordinates(comp)
result = wcsobj.pixel_to_world(x, y)
assert isinstance(comp, coord.SkyCoord)
assert isinstance(result, coord.SkyCoord)
assert_allclose(comp.data.lon, result.data.lon)
assert_allclose(comp.data.lat, result.data.lat)


def test_array_index_to_world():
wcsobj = example_wcs
comp = wcsobj(x, y, with_units=True)
comp = wcsobj.output_frame.coordinates(comp)
result = wcsobj.array_index_to_world(y, x)
assert isinstance(comp, coord.SkyCoord)
assert isinstance(result, coord.SkyCoord)
assert_allclose(comp.data.lon, result.data.lon)
assert_allclose(comp.data.lat, result.data.lat)


def test_world_to_pixel():
wcsobj = example_wcs
assert_allclose(wcsobj.world_to_pixel(sky), wcsobj.invert(ra, dec, with_units=False))


def test_world_to_array_index():
wcsobj = example_wcs
assert_allclose(wcsobj.world_to_array_index(sky), wcsobj.invert(ra, dec, with_units=False)[::-1])


48 changes: 48 additions & 0 deletions gwcs/tests/test_coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,13 @@ def test_coordinate_to_quantity_celestial(inp):
assert_quantity_allclose(lon, 10 * u.deg)
assert_quantity_allclose(lat, 20 * u.deg)

with pytest.raises(ValueError):
cel.coordinate_to_quantity(10*u.deg, 2*u.deg, 3*u.deg)

with pytest.raises(ValueError):
cel.coordinate_to_quantity((1, 2))


@pytest.mark.parametrize('inp', [
(100,),
(100 * u.nm,),
Expand Down Expand Up @@ -274,6 +280,19 @@ def test_coordinate_to_quantity_frame_2d():
assert_quantity_allclose(output, exp)


def test_coordinate_to_quantity_error():
frame = cf.Frame2D(unit=(u.one, u.arcsec))
with pytest.raises(ValueError):
frame.coordinate_to_quantity(1)

with pytest.raises(ValueError):
comp1.coordinate_to_quantity((1, 1), 2)

frame = cf.TemporalFrame(unit=u.s)
with pytest.raises(ValueError):
frame.coordinate_to_quantity(1)


def test_axis_physical_type():
assert icrs.axis_physical_type == ("pos.eq.ra", "pos.eq.dec")
assert spec1.axis_physical_type == ("em.freq",)
Expand All @@ -297,3 +316,32 @@ def test_axis_physical_type():

fr2d = cf.Frame2D(name='d', axis_physical_type=("pos.x", "pos.y"))
assert fr2d.axis_physical_type == ('custom:pos.x', 'custom:pos.y')

with pytest.raises(ValueError):
cf.CelestialFrame(reference_frame=coord.ICRS(), axis_physical_type=("pos.eq.ra",) )

fr = cf.CelestialFrame(reference_frame=coord.ICRS(), axis_physical_type=("ra", "dec"))
assert fr.axis_physical_type == ("custom:pos.ra", "custom:pos.dec")

fr = cf.CelestialFrame(reference_frame=coord.BarycentricTrueEcliptic())
assert fr.axis_physical_type == ('pos.ecliptic.lon', 'pos.ecliptic.lat')

frame = cf.CoordinateFrame(name='custom_frame', axes_type=("SPATIAL",), axes_order=(0,), axis_physical_type="length", axes_names="x", naxes=1)
assert frame.axis_physical_type == ("custom:length",)
frame = cf.CoordinateFrame(name='custom_frame', axes_type=("SPATIAL",), axes_order=(0,), axis_physical_type=("length",), axes_names="x", naxes=1)
assert frame.axis_physical_type == ("custom:length",)
with pytest.raises(ValueError):
cf.CoordinateFrame(name='custom_frame', axes_type=("SPATIAL",), axes_order=(0,), axis_physical_type=("length", "length"), naxes=1)


def test_base_frame():
with pytest.raises(ValueError):
cf.CoordinateFrame(name='custom_frame',
axes_type=("SPATIAL",),
naxes=1, axes_order=(0,),
axes_names=("x", "y"))
frame = cf.CoordinateFrame(name='custom_frame', axes_type=("SPATIAL",), axes_order=(0,), axes_names="x", naxes=1)
assert frame.naxes == 1
assert frame.axes_names == ("x",)

frame.coordinate_to_quantity(1, 2)

0 comments on commit 522e973

Please sign in to comment.