diff --git a/src/pymap3d/eci.py b/src/pymap3d/eci.py index 1ee83e1..75d8111 100644 --- a/src/pymap3d/eci.py +++ b/src/pymap3d/eci.py @@ -4,7 +4,7 @@ from datetime import datetime -from numpy import array, atleast_1d, column_stack, cos, empty, sin +import numpy as np try: import astropy.units as u @@ -52,16 +52,19 @@ def eci2ecef(x, y, z, time: datetime) -> tuple: y_ecef = itrs.y.value z_ecef = itrs.z.value except NameError: - x = atleast_1d(x) - y = atleast_1d(y) - z = atleast_1d(z) - gst = atleast_1d(greenwichsrt(juliandate(time))) - assert x.shape == y.shape == z.shape, f"shape mismatch: x: ${x.shape} y: {y.shape} z: {z.shape}" - if gst.size > 1: - assert x.size == gst.size, f"shape mismatch: x: {x.shape} gst: {gst.shape}" - - eci = column_stack((x.ravel(), y.ravel(), z.ravel())) - ecef = empty((x.size, 3)) + x = np.atleast_1d(x) + y = np.atleast_1d(y) + z = np.atleast_1d(z) + gst = np.atleast_1d(greenwichsrt(juliandate(time))) + assert ( + x.shape == y.shape == z.shape + ), f"shape mismatch: x: ${x.shape} y: {y.shape} z: {z.shape}" + if gst.size == 1 and x.size != 1: + gst = np.broadcast_to(gst, x.shape) + assert x.size == gst.size, f"shape mismatch: x: {x.shape} gst: {gst.shape}" + + eci = np.column_stack((x.ravel(), y.ravel(), z.ravel())) + ecef = np.empty((x.size, 3)) for i in range(eci.shape[0]): ecef[i, :] = R3(gst[i]) @ eci[i, :].T @@ -109,15 +112,15 @@ def ecef2eci(x, y, z, time: datetime) -> tuple: y_eci = eci.y.value z_eci = eci.z.value except NameError: - x = atleast_1d(x) - y = atleast_1d(y) - z = atleast_1d(z) - gst = atleast_1d(greenwichsrt(juliandate(time))) + x = np.atleast_1d(x) + y = np.atleast_1d(y) + z = np.atleast_1d(z) + gst = np.atleast_1d(greenwichsrt(juliandate(time))) assert x.shape == y.shape == z.shape assert x.size == gst.size - ecef = column_stack((x.ravel(), y.ravel(), z.ravel())) - eci = empty((x.size, 3)) + ecef = np.column_stack((x.ravel(), y.ravel(), z.ravel())) + eci = np.empty((x.size, 3)) for i in range(x.size): eci[i, :] = R3(gst[i]).T @ ecef[i, :] @@ -130,4 +133,4 @@ def ecef2eci(x, y, z, time: datetime) -> tuple: def R3(x: float): """Rotation matrix for ECI""" - return array([[cos(x), sin(x), 0], [-sin(x), cos(x), 0], [0, 0, 1]]) + return np.array([[np.cos(x), np.sin(x), 0], [-np.sin(x), np.cos(x), 0], [0, 0, 1]])