From 2d72c75fafb00a449945fa3b0df91d7ad5105e6d Mon Sep 17 00:00:00 2001 From: David Kaplan Date: Tue, 9 Jul 2024 09:46:21 -0500 Subject: [PATCH] type hints for derived quantities --- CHANGELOG-unreleased.md | 1 + src/pint/derived_quantities.py | 129 +++++++++++++++++++++------------ 2 files changed, 82 insertions(+), 48 deletions(-) diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md index d8ff0a4b7..77def071d 100644 --- a/CHANGELOG-unreleased.md +++ b/CHANGELOG-unreleased.md @@ -39,6 +39,7 @@ the released changes. - `pint.models.chromatic_model.ChromaticCM` for a Taylor series representation of the variable-index chromatic delay. - Whitened residuals (`white-res`) as a plotting axis in `pintk` - `TOAs.get_Tspan()` method +- Type hints in `pint.derived_quantities` ### Fixed - `pint.utils.split_swx()` to use updated `SolarWindDispersionX()` parameter naming convention - Fix #1759 by changing order of comparison diff --git a/src/pint/derived_quantities.py b/src/pint/derived_quantities.py index 104b73d8d..a69331f13 100644 --- a/src/pint/derived_quantities.py +++ b/src/pint/derived_quantities.py @@ -4,6 +4,7 @@ import astropy.constants as const import astropy.units as u import numpy as np +from typing import Optional, List, Tuple, Union import pint @@ -33,7 +34,9 @@ @u.quantity_input( p=[u.Hz, u.s], pd=[u.Hz / u.s, u.s / u.s], pdd=[u.Hz / u.s**2, u.s / u.s**2] ) -def p_to_f(p, pd, pdd=None): +def p_to_f( + p: u.Quantity, pd: u.Quantity, pdd: Optional[u.Quantity] = None +) -> Tuple[u.Quantity]: """Converts P, Pdot to F, Fdot (or vice versa) Convert period, period derivative and period second @@ -71,7 +74,11 @@ def p_to_f(p, pd, pdd=None): if pdd == 0.0 else 2.0 * pd * pd / (p**3.0) - pdd / (p * p) ) - return [f, fd, fdd] + return (f, fd, fdd) + + +# alias for the above +f_to_p = p_to_f @u.quantity_input( @@ -80,7 +87,12 @@ def p_to_f(p, pd, pdd=None): pdorfd=[u.Hz / u.s, u.s / u.s], pdorfderr=[u.Hz / u.s, u.s / u.s], ) -def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): +def pferrs( + porf: u.Quantity, + porferr: u.Quantity, + pdorfd: Optional[u.Quantity] = None, + pdorfderr: Optional[u.Quantity] = None, +) -> Tuple[u.Quantity]: """Convert P, Pdot to F, Fdot with uncertainties (or vice versa). Calculate the period or frequency errors and @@ -120,15 +132,16 @@ def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): return [1.0 / porf, porferr / porf**2.0] forperr = porferr / porf**2.0 fdorpderr = np.sqrt( - (4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0 - + pdorfderr**2.0 / porf**4.0 + (4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0 + pdorfderr**2.0 / porf**4.0 ) [forp, fdorpd] = p_to_f(porf, pdorfd) - return [forp, forperr, fdorpd, fdorpderr] + return (forp, forperr, fdorpd, fdorpderr) -@u.quantity_input(fo=u.Hz) -def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz): +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s, fo=u.Hz) +def pulsar_age( + f: u, Quantity, fdot: u.Quantity, n: int = 3, fo: u.Quantity = 1e99 * u.Hz +) -> u.Quantity: """Compute pulsar characteristic age Return the age of a pulsar given the spin frequency @@ -170,8 +183,10 @@ def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz): return (-f / ((n - 1.0) * fdot) * (1.0 - (f / fo) ** (n - 1.0))).to(u.yr) -@u.quantity_input(I=u.g * u.cm**2) -def pulsar_edot(f: u.Hz, fdot: u.Hz / u.s, I=1.0e45 * u.g * u.cm**2): +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s, I=u.g * u.cm**2) +def pulsar_edot( + f: u.Quantity, fdot: u.Quantity, I: u.Quantity = 1.0e45 * u.g * u.cm**2 +) -> u.Quantity: """Compute pulsar spindown energy loss rate Return the pulsar `Edot` (:math:`\dot E`, in erg/s) given the spin frequency `f` and @@ -206,8 +221,8 @@ def pulsar_edot(f: u.Hz, fdot: u.Hz / u.s, I=1.0e45 * u.g * u.cm**2): return (-4.0 * np.pi**2 * I * f * fdot).to(u.erg / u.s) -@u.quantity_input -def pulsar_B(f: u.Hz, fdot: u.Hz / u.s): +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s) +def pulsar_B(f: u.Quantity, fdot: u.Quantity) -> u.Quantity: """Compute pulsar surface magnetic field Return the estimated pulsar surface magnetic field strength @@ -241,8 +256,8 @@ def pulsar_B(f: u.Hz, fdot: u.Hz / u.s): return 3.2e19 * u.G * np.sqrt(-fdot.to_value(u.Hz / u.s) / f.to_value(u.Hz) ** 3.0) -@u.quantity_input -def pulsar_B_lightcyl(f: u.Hz, fdot: u.Hz / u.s): +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s) +def pulsar_B_lightcyl(f: u.Quantity, fdot: u.Quantity) -> u.Quantity: """Compute pulsar magnetic field at the light cylinder Return the estimated pulsar magnetic field strength at the @@ -283,8 +298,8 @@ def pulsar_B_lightcyl(f: u.Hz, fdot: u.Hz / u.s): ) -@u.quantity_input -def mass_funct(pb: u.d, x: u.cm): +@u.quantity_input(pb=u.d, x=u.cm) +def mass_funct(pb: u.Quantity, x: u.Quantity) -> u.Quantity: """Compute binary mass function from period and semi-major axis Can handle scalar or array inputs. @@ -324,8 +339,8 @@ def mass_funct(pb: u.d, x: u.cm): return fm.to(u.solMass) -@u.quantity_input -def mass_funct2(mp: u.Msun, mc: u.Msun, i: u.deg): +@u.quantity_input(mp=u.Msun, mc=u.Msun, i=u.deg) +def mass_funct2(mp: u.Quantity, mc: u.Quantity, i: u.Quantity) -> u.Quantity: """Compute binary mass function from masses and inclination Can handle scalar or array inputs. @@ -369,8 +384,10 @@ def mass_funct2(mp: u.Msun, mc: u.Msun, i: u.deg): return (mc * np.sin(i)) ** 3.0 / (mc + mp) ** 2.0 -@u.quantity_input -def pulsar_mass(pb: u.d, x: u.cm, mc: u.Msun, i: u.deg): +@u.quantity_input(pb=u.d, x=u.cm, mc=u.Msun, i=u.deg) +def pulsar_mass( + pb: u.Quantity, x: u.Quantity, mc: u.Quantity, i: u.Quantity +) -> u.Quantity: """Compute pulsar mass from orbital parameters Return the pulsar mass (in solar mass units) for a binary. @@ -436,8 +453,13 @@ def pulsar_mass(pb: u.d, x: u.cm, mc: u.Msun, i: u.deg): return ((-cb + np.sqrt(4 * massfunct * mc**3 * sini**3)) / (2 * ca)).to(u.Msun) -@u.quantity_input(inc=u.deg, mpsr=u.solMass) -def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass): +@u.quantity_input(pb=u.d, x=u.cm, i=u.deg, mp=u.solMass) +def companion_mass( + pb: u.Quantity, + x: u.Quantity, + i: u.Quantity = 60.0 * u.deg, + mp: u.Quantity = 1.4 * u.solMass, +) -> u.Quantity: """Commpute the companion mass from the orbital parameters Compute companion mass for a binary system from orbital mechanics, @@ -515,17 +537,12 @@ def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass): # delta1 is always <0 # delta1 = 2 * b ** 3 - 9 * a * b * c + 27 * a ** 2 * d delta1 = ( - -2 * massfunct**3 - - 18 * a * mp * massfunct**2 - - 27 * a**2 * massfunct * mp**2 + -2 * massfunct**3 - 18 * a * mp * massfunct**2 - 27 * a**2 * massfunct * mp**2 ) # Q**2 is always > 0, so this is never a problem # in terms of complex numbers # Q = np.sqrt(delta1**2 - 4*delta0**3) - Q = np.sqrt( - 108 * a**3 * mp**3 * massfunct**3 - + 729 * a**4 * mp**4 * massfunct**2 - ) + Q = np.sqrt(108 * a**3 * mp**3 * massfunct**3 + 729 * a**4 * mp**4 * massfunct**2) # this could be + or - Q # pick the - branch since delta1 is <0 so that delta1 - Q is never near 0 Ccubed = 0.5 * (delta1 + Q) @@ -540,8 +557,10 @@ def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass): return x1.to(u.Msun) -@u.quantity_input -def pbdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): +@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d, e=u.dimensionless_unscaled) +def pbdot( + mp: u.Quantity, mc: u.Quantity, pb: u.Quantity, e: Union[float, u.Quantity] +) -> u.Quantity: """Post-Keplerian orbital decay pbdot, assuming general relativity. pbdot (:math:`\dot P_B`) is the change in the binary orbital period @@ -603,8 +622,13 @@ def pbdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): return value.to(u.s / u.s) -@u.quantity_input -def gamma(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): +@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d, e=u.dimensionless_unscaled) +def gamma( + mp: u.Quantity, + mc: u.Quantity, + pb: u.Quantity, + e: Union[float, u.Quantity], +) -> u.Quantity: """Post-Keplerian time dilation and gravitational redshift gamma, assuming general relativity. gamma (:math:`\gamma`) is the amplitude of the modification in arrival times caused by the varying @@ -659,8 +683,13 @@ def gamma(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): return value.to(u.s) -@u.quantity_input -def omdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): +@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d, e=u.dimensionless_unscaled) +def omdot( + mp: u.Quantity, + mc: u.Quantity, + pb: u.Quantity, + e: Union[float, u.Quantity], +) -> u.Quantity: """Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. omdot (:math:`\dot \omega`) is the relativistic advance of periastron. @@ -714,8 +743,8 @@ def omdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): return value.to(u.deg / u.yr, equivalencies=u.dimensionless_angles()) -@u.quantity_input -def sini(mp: u.Msun, mc: u.Msun, pb: u.d, x: u.cm): +@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d, x=u.cm) +def sini(mp: u.Quantity, mc: u.Quantity, pb: u.Quantity, x: u.Quantity) -> u.Quantity: """Post-Keplerian sine of inclination, assuming general relativity. Can handle scalar or array inputs. @@ -768,8 +797,8 @@ def sini(mp: u.Msun, mc: u.Msun, pb: u.d, x: u.cm): ).decompose() -@u.quantity_input -def dr(mp: u.Msun, mc: u.Msun, pb: u.d): +@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d) +def dr(mp: u.Quantity, mc: u.Quantity, pb: u.Quantity) -> u.Quantity: """Post-Keplerian Roemer delay term dr (:math:`\delta_r`) is part of the relativistic deformation of the orbit @@ -818,8 +847,8 @@ def dr(mp: u.Msun, mc: u.Msun, pb: u.d): ).decompose() -@u.quantity_input -def dth(mp: u.Msun, mc: u.Msun, pb: u.d): +@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d) +def dth(mp: u.Quantity, mc: u.Quantity, pb: u.Quantity) -> u.Quantity: """Post-Keplerian Roemer delay term dth (:math:`\delta_{\\theta}`) is part of the relativistic deformation of the orbit @@ -868,8 +897,10 @@ def dth(mp: u.Msun, mc: u.Msun, pb: u.d): ).decompose() -@u.quantity_input -def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): +@u.quantity_input(omdot=u.deg / u.yr, pb=u.d, e=u.dimensionless_unscaled) +def omdot_to_mtot( + omdot: u.Quantity, pb: u.Quantity, e: Union[float, u.Quantity] +) -> u.Quantity: """Determine total mass from Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. @@ -931,7 +962,9 @@ def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): @u.quantity_input(pb=u.d, mp=u.Msun, mc=u.Msun, i=u.deg) -def a1sini(mp, mc, pb, i=90 * u.deg): +def a1sini( + mp: u.Quantity, mc: u.Quantity, pb: u.Quantity, i: u.Quantity = 90 * u.deg +) -> u.Quantity: """Pulsar's semi-major axis. The full semi-major axis is given by Kepler's third law. This is the @@ -982,8 +1015,8 @@ def a1sini(mp, mc, pb, i=90 * u.deg): ).to(pint.ls) -@u.quantity_input -def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc): +@u.quantity_input(pmtot=u.mas / u.yr, D=u.kpc) +def shklovskii_factor(pmtot: u.Quantity, D: u.Quantity) -> u.Quantity: """Compute magnitude of Shklovskii correction factor. Computes the Shklovskii correction factor, as defined in Eq 8.12 of Lorimer & Kramer (2005) [10]_ @@ -1020,8 +1053,8 @@ def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc): return a_s -@u.quantity_input -def dispersion_slope(dm: pint.dmu): +@u.quantity_input(dm=pint.dmu) +def dispersion_slope(dm: u.Quantity) -> u.Quantity: """Compute the dispersion slope. This is equal to DMconst * DM.