Skip to content

Commit

Permalink
type hints for derived quantities
Browse files Browse the repository at this point in the history
  • Loading branch information
dlakaplan committed Jul 9, 2024
1 parent c28fc87 commit 2d72c75
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 48 deletions.
1 change: 1 addition & 0 deletions CHANGELOG-unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ the released changes.
- `pint.models.chromatic_model.ChromaticCM` for a Taylor series representation of the variable-index chromatic delay.
- Whitened residuals (`white-res`) as a plotting axis in `pintk`
- `TOAs.get_Tspan()` method
- Type hints in `pint.derived_quantities`
### Fixed
- `pint.utils.split_swx()` to use updated `SolarWindDispersionX()` parameter naming convention
- Fix #1759 by changing order of comparison
Expand Down
129 changes: 81 additions & 48 deletions src/pint/derived_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import astropy.constants as const
import astropy.units as u
import numpy as np
from typing import Optional, List, Tuple, Union

import pint

Expand Down Expand Up @@ -33,7 +34,9 @@
@u.quantity_input(
p=[u.Hz, u.s], pd=[u.Hz / u.s, u.s / u.s], pdd=[u.Hz / u.s**2, u.s / u.s**2]
)
def p_to_f(p, pd, pdd=None):
def p_to_f(
p: u.Quantity, pd: u.Quantity, pdd: Optional[u.Quantity] = None
) -> Tuple[u.Quantity]:
"""Converts P, Pdot to F, Fdot (or vice versa)
Convert period, period derivative and period second
Expand Down Expand Up @@ -71,7 +74,11 @@ def p_to_f(p, pd, pdd=None):
if pdd == 0.0
else 2.0 * pd * pd / (p**3.0) - pdd / (p * p)
)
return [f, fd, fdd]
return (f, fd, fdd)


# alias for the above
f_to_p = p_to_f


@u.quantity_input(
Expand All @@ -80,7 +87,12 @@ def p_to_f(p, pd, pdd=None):
pdorfd=[u.Hz / u.s, u.s / u.s],
pdorfderr=[u.Hz / u.s, u.s / u.s],
)
def pferrs(porf, porferr, pdorfd=None, pdorfderr=None):
def pferrs(
porf: u.Quantity,
porferr: u.Quantity,
pdorfd: Optional[u.Quantity] = None,
pdorfderr: Optional[u.Quantity] = None,
) -> Tuple[u.Quantity]:
"""Convert P, Pdot to F, Fdot with uncertainties (or vice versa).
Calculate the period or frequency errors and
Expand Down Expand Up @@ -120,15 +132,16 @@ def pferrs(porf, porferr, pdorfd=None, pdorfderr=None):
return [1.0 / porf, porferr / porf**2.0]
forperr = porferr / porf**2.0
fdorpderr = np.sqrt(
(4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0
+ pdorfderr**2.0 / porf**4.0
(4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0 + pdorfderr**2.0 / porf**4.0
)
[forp, fdorpd] = p_to_f(porf, pdorfd)
return [forp, forperr, fdorpd, fdorpderr]
return (forp, forperr, fdorpd, fdorpderr)


@u.quantity_input(fo=u.Hz)
def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz):
@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s, fo=u.Hz)
def pulsar_age(
f: u, Quantity, fdot: u.Quantity, n: int = 3, fo: u.Quantity = 1e99 * u.Hz
) -> u.Quantity:
"""Compute pulsar characteristic age
Return the age of a pulsar given the spin frequency
Expand Down Expand Up @@ -170,8 +183,10 @@ def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz):
return (-f / ((n - 1.0) * fdot) * (1.0 - (f / fo) ** (n - 1.0))).to(u.yr)


@u.quantity_input(I=u.g * u.cm**2)
def pulsar_edot(f: u.Hz, fdot: u.Hz / u.s, I=1.0e45 * u.g * u.cm**2):
@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s, I=u.g * u.cm**2)
def pulsar_edot(
f: u.Quantity, fdot: u.Quantity, I: u.Quantity = 1.0e45 * u.g * u.cm**2
) -> u.Quantity:
"""Compute pulsar spindown energy loss rate
Return the pulsar `Edot` (:math:`\dot E`, in erg/s) given the spin frequency `f` and
Expand Down Expand Up @@ -206,8 +221,8 @@ def pulsar_edot(f: u.Hz, fdot: u.Hz / u.s, I=1.0e45 * u.g * u.cm**2):
return (-4.0 * np.pi**2 * I * f * fdot).to(u.erg / u.s)


@u.quantity_input
def pulsar_B(f: u.Hz, fdot: u.Hz / u.s):
@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s)
def pulsar_B(f: u.Quantity, fdot: u.Quantity) -> u.Quantity:
"""Compute pulsar surface magnetic field
Return the estimated pulsar surface magnetic field strength
Expand Down Expand Up @@ -241,8 +256,8 @@ def pulsar_B(f: u.Hz, fdot: u.Hz / u.s):
return 3.2e19 * u.G * np.sqrt(-fdot.to_value(u.Hz / u.s) / f.to_value(u.Hz) ** 3.0)


@u.quantity_input
def pulsar_B_lightcyl(f: u.Hz, fdot: u.Hz / u.s):
@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s)
def pulsar_B_lightcyl(f: u.Quantity, fdot: u.Quantity) -> u.Quantity:
"""Compute pulsar magnetic field at the light cylinder
Return the estimated pulsar magnetic field strength at the
Expand Down Expand Up @@ -283,8 +298,8 @@ def pulsar_B_lightcyl(f: u.Hz, fdot: u.Hz / u.s):
)


@u.quantity_input
def mass_funct(pb: u.d, x: u.cm):
@u.quantity_input(pb=u.d, x=u.cm)
def mass_funct(pb: u.Quantity, x: u.Quantity) -> u.Quantity:
"""Compute binary mass function from period and semi-major axis
Can handle scalar or array inputs.
Expand Down Expand Up @@ -324,8 +339,8 @@ def mass_funct(pb: u.d, x: u.cm):
return fm.to(u.solMass)


@u.quantity_input
def mass_funct2(mp: u.Msun, mc: u.Msun, i: u.deg):
@u.quantity_input(mp=u.Msun, mc=u.Msun, i=u.deg)
def mass_funct2(mp: u.Quantity, mc: u.Quantity, i: u.Quantity) -> u.Quantity:
"""Compute binary mass function from masses and inclination
Can handle scalar or array inputs.
Expand Down Expand Up @@ -369,8 +384,10 @@ def mass_funct2(mp: u.Msun, mc: u.Msun, i: u.deg):
return (mc * np.sin(i)) ** 3.0 / (mc + mp) ** 2.0


@u.quantity_input
def pulsar_mass(pb: u.d, x: u.cm, mc: u.Msun, i: u.deg):
@u.quantity_input(pb=u.d, x=u.cm, mc=u.Msun, i=u.deg)
def pulsar_mass(
pb: u.Quantity, x: u.Quantity, mc: u.Quantity, i: u.Quantity
) -> u.Quantity:
"""Compute pulsar mass from orbital parameters
Return the pulsar mass (in solar mass units) for a binary.
Expand Down Expand Up @@ -436,8 +453,13 @@ def pulsar_mass(pb: u.d, x: u.cm, mc: u.Msun, i: u.deg):
return ((-cb + np.sqrt(4 * massfunct * mc**3 * sini**3)) / (2 * ca)).to(u.Msun)


@u.quantity_input(inc=u.deg, mpsr=u.solMass)
def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass):
@u.quantity_input(pb=u.d, x=u.cm, i=u.deg, mp=u.solMass)
def companion_mass(
pb: u.Quantity,
x: u.Quantity,
i: u.Quantity = 60.0 * u.deg,
mp: u.Quantity = 1.4 * u.solMass,
) -> u.Quantity:
"""Commpute the companion mass from the orbital parameters
Compute companion mass for a binary system from orbital mechanics,
Expand Down Expand Up @@ -515,17 +537,12 @@ def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass):
# delta1 is always <0
# delta1 = 2 * b ** 3 - 9 * a * b * c + 27 * a ** 2 * d
delta1 = (
-2 * massfunct**3
- 18 * a * mp * massfunct**2
- 27 * a**2 * massfunct * mp**2
-2 * massfunct**3 - 18 * a * mp * massfunct**2 - 27 * a**2 * massfunct * mp**2
)
# Q**2 is always > 0, so this is never a problem
# in terms of complex numbers
# Q = np.sqrt(delta1**2 - 4*delta0**3)
Q = np.sqrt(
108 * a**3 * mp**3 * massfunct**3
+ 729 * a**4 * mp**4 * massfunct**2
)
Q = np.sqrt(108 * a**3 * mp**3 * massfunct**3 + 729 * a**4 * mp**4 * massfunct**2)
# this could be + or - Q
# pick the - branch since delta1 is <0 so that delta1 - Q is never near 0
Ccubed = 0.5 * (delta1 + Q)
Expand All @@ -540,8 +557,10 @@ def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass):
return x1.to(u.Msun)


@u.quantity_input
def pbdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled):
@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d, e=u.dimensionless_unscaled)
def pbdot(
mp: u.Quantity, mc: u.Quantity, pb: u.Quantity, e: Union[float, u.Quantity]
) -> u.Quantity:
"""Post-Keplerian orbital decay pbdot, assuming general relativity.
pbdot (:math:`\dot P_B`) is the change in the binary orbital period
Expand Down Expand Up @@ -603,8 +622,13 @@ def pbdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled):
return value.to(u.s / u.s)


@u.quantity_input
def gamma(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled):
@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d, e=u.dimensionless_unscaled)
def gamma(
mp: u.Quantity,
mc: u.Quantity,
pb: u.Quantity,
e: Union[float, u.Quantity],
) -> u.Quantity:
"""Post-Keplerian time dilation and gravitational redshift gamma, assuming general relativity.
gamma (:math:`\gamma`) is the amplitude of the modification in arrival times caused by the varying
Expand Down Expand Up @@ -659,8 +683,13 @@ def gamma(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled):
return value.to(u.s)


@u.quantity_input
def omdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled):
@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d, e=u.dimensionless_unscaled)
def omdot(
mp: u.Quantity,
mc: u.Quantity,
pb: u.Quantity,
e: Union[float, u.Quantity],
) -> u.Quantity:
"""Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity.
omdot (:math:`\dot \omega`) is the relativistic advance of periastron.
Expand Down Expand Up @@ -714,8 +743,8 @@ def omdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled):
return value.to(u.deg / u.yr, equivalencies=u.dimensionless_angles())


@u.quantity_input
def sini(mp: u.Msun, mc: u.Msun, pb: u.d, x: u.cm):
@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d, x=u.cm)
def sini(mp: u.Quantity, mc: u.Quantity, pb: u.Quantity, x: u.Quantity) -> u.Quantity:
"""Post-Keplerian sine of inclination, assuming general relativity.
Can handle scalar or array inputs.
Expand Down Expand Up @@ -768,8 +797,8 @@ def sini(mp: u.Msun, mc: u.Msun, pb: u.d, x: u.cm):
).decompose()


@u.quantity_input
def dr(mp: u.Msun, mc: u.Msun, pb: u.d):
@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d)
def dr(mp: u.Quantity, mc: u.Quantity, pb: u.Quantity) -> u.Quantity:
"""Post-Keplerian Roemer delay term
dr (:math:`\delta_r`) is part of the relativistic deformation of the orbit
Expand Down Expand Up @@ -818,8 +847,8 @@ def dr(mp: u.Msun, mc: u.Msun, pb: u.d):
).decompose()


@u.quantity_input
def dth(mp: u.Msun, mc: u.Msun, pb: u.d):
@u.quantity_input(mp=u.Msun, mc=u.Msun, pb=u.d)
def dth(mp: u.Quantity, mc: u.Quantity, pb: u.Quantity) -> u.Quantity:
"""Post-Keplerian Roemer delay term
dth (:math:`\delta_{\\theta}`) is part of the relativistic deformation of the orbit
Expand Down Expand Up @@ -868,8 +897,10 @@ def dth(mp: u.Msun, mc: u.Msun, pb: u.d):
).decompose()


@u.quantity_input
def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled):
@u.quantity_input(omdot=u.deg / u.yr, pb=u.d, e=u.dimensionless_unscaled)
def omdot_to_mtot(
omdot: u.Quantity, pb: u.Quantity, e: Union[float, u.Quantity]
) -> u.Quantity:
"""Determine total mass from Post-Keplerian longitude of periastron precession rate omdot,
assuming general relativity.
Expand Down Expand Up @@ -931,7 +962,9 @@ def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled):


@u.quantity_input(pb=u.d, mp=u.Msun, mc=u.Msun, i=u.deg)
def a1sini(mp, mc, pb, i=90 * u.deg):
def a1sini(
mp: u.Quantity, mc: u.Quantity, pb: u.Quantity, i: u.Quantity = 90 * u.deg
) -> u.Quantity:
"""Pulsar's semi-major axis.
The full semi-major axis is given by Kepler's third law. This is the
Expand Down Expand Up @@ -982,8 +1015,8 @@ def a1sini(mp, mc, pb, i=90 * u.deg):
).to(pint.ls)


@u.quantity_input
def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc):
@u.quantity_input(pmtot=u.mas / u.yr, D=u.kpc)
def shklovskii_factor(pmtot: u.Quantity, D: u.Quantity) -> u.Quantity:
"""Compute magnitude of Shklovskii correction factor.
Computes the Shklovskii correction factor, as defined in Eq 8.12 of Lorimer & Kramer (2005) [10]_
Expand Down Expand Up @@ -1020,8 +1053,8 @@ def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc):
return a_s


@u.quantity_input
def dispersion_slope(dm: pint.dmu):
@u.quantity_input(dm=pint.dmu)
def dispersion_slope(dm: u.Quantity) -> u.Quantity:
"""Compute the dispersion slope.
This is equal to DMconst * DM.
Expand Down

0 comments on commit 2d72c75

Please sign in to comment.