Skip to content

Commit

Permalink
added type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
dlakaplan committed Feb 12, 2024
1 parent 1f54b53 commit 34f0c08
Showing 1 changed file with 99 additions and 56 deletions.
155 changes: 99 additions & 56 deletions src/pint/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""
from collections import OrderedDict
from copy import deepcopy
from typing import Optional, Union, List, Dict
import pathlib

import astropy.units as u
import numpy as np
Expand All @@ -22,7 +24,14 @@
]


def zero_residuals(ts, model, *, subtract_mean=True, maxiter=10, tolerance=None):
def zero_residuals(
ts: pint.toa.TOAs,
model: pint.models.timing_model.TimingModel,
*,
subtract_mean: bool = True,
maxiter: int = 10,
tolerance: Optional[u.Quantity] = None,
):
"""Use a model to adjust a TOAs object, setting residuals to 0 iteratively.
Parameters
Expand Down Expand Up @@ -62,7 +71,11 @@ def zero_residuals(ts, model, *, subtract_mean=True, maxiter=10, tolerance=None)
)


def get_fake_toa_clock_versions(model, include_bipm=False, include_gps=True):
def get_fake_toa_clock_versions(
model: pint.models.timing_model.TimingModel,
include_bipm: bool = False,
include_gps: bool = True,
) -> dict:
"""Get the clock settings (corrections, etc) for fake TOAs
Parameters
Expand All @@ -75,6 +88,10 @@ def get_fake_toa_clock_versions(model, include_bipm=False, include_gps=True):
include_gps : bool, optional
Whether or not to disable UTC(GPS)->UTC clock correction
(see :class:`pint.observatory.topo_obs.TopoObs`)
Returns
-------
dict
"""
bipm_version = bipm_default
if model["CLOCK"].value is not None:
Expand Down Expand Up @@ -109,13 +126,13 @@ def get_fake_toa_clock_versions(model, include_bipm=False, include_gps=True):


def make_fake_toas(
ts,
model,
add_noise=False,
add_correlated_noise=False,
name="fake",
subtract_mean=True,
):
ts: pint.toa.TOAs,
model: pint.models.timing_model.TimingModel,
add_noise: bool = False,
add_correlated_noise: bool = False,
name: str = "fake",
subtract_mean: bool = True,
) -> pint.toa.TOAs:
"""Make toas from an array of times
Can include alternating frequencies if fed an array of frequencies,
Expand Down Expand Up @@ -166,8 +183,23 @@ def make_fake_toas(
return tsim


def update_fake_dms(model, ts, dm_error, add_noise):
"""Update simulated wideband DM information in TOAs."""
def update_fake_dms(
model: pint.models.timing_model.TimingModel,
ts: pint.toa.TOAs,
dm_error: u.Quantity,
add_noise: bool,
) -> pint.toa.TOAs:
"""Update simulated wideband DM information in TOAs.
Parameters
----------
model: pint.models.timing_model.TimingModel
ts : pint.toa.TOAs
Input TOAs
dm_error: u.Quantity
add_noise : bool, optional
Add noise to the DMs (otherwise `dm_error` just populates the column)
"""
toas = deepcopy(ts)

dm_errors = dm_error * np.ones(len(toas))
Expand All @@ -187,25 +219,25 @@ def update_fake_dms(model, ts, dm_error, add_noise):


def make_fake_toas_uniform(
startMJD,
endMJD,
ntoas,
model,
fuzz=0,
freq=1400 * u.MHz,
obs="GBT",
error=1 * u.us,
add_noise=False,
add_correlated_noise=False,
wideband=False,
wideband_dm_error=1e-4 * pint.dmu,
name="fake",
include_bipm=False,
include_gps=True,
multi_freqs_in_epoch=False,
flags=None,
subtract_mean=True,
):
startMJD: float | u.Quantity | time.Time,
endMJD: float | u.Quantity | time.Time,
ntoas: int,
model: pint.models.timing_model.TimingModel,
fuzz: u.Quantity = 0,
freq: u.Quantity = 1400 * u.MHz,
obs: str = "GBT",
error: u.Quantity = 1 * u.us,
add_noise: bool = False,
add_correlated_noise: bool = False,
wideband: bool = False,
wideband_dm_error: u.Quantity = 1e-4 * pint.dmu,
name: str = "fake",
include_bipm: bool = False,
include_gps: bool = True,
multi_freqs_in_epoch: bool = False,
flags: Optional[dict] = None,
subtract_mean: bool = True,
) -> pint.toa.TOAs:
"""Simulate uniformly spaced TOAs.
Parameters
Expand Down Expand Up @@ -329,22 +361,22 @@ def make_fake_toas_uniform(


def make_fake_toas_fromMJDs(
MJDs,
model,
freq=1400 * u.MHz,
obs="GBT",
error=1 * u.us,
add_noise=False,
add_correlated_noise=False,
wideband=False,
wideband_dm_error=1e-4 * pint.dmu,
name="fake",
include_bipm=False,
include_gps=True,
multi_freqs_in_epoch=False,
flags=None,
subtract_mean=True,
):
MJDs: u.Quantity | time.Time | np.ndarray,
model: pint.models.timing_model.TimingModel,
freq: u.Quantity = 1400 * u.MHz,
obs: str = "GBT",
error: u.Quantity = 1 * u.us,
add_noise: bool = False,
add_correlated_noise: bool = False,
wideband: bool = False,
wideband_dm_error: u.Quantity = 1e-4 * pint.dmu,
name: str = "fake",
include_bipm: bool = False,
include_gps: bool = True,
multi_freqs_in_epoch: bool = False,
flags: Optional[dict] = None,
subtract_mean: bool = True,
) -> pint.toa.TOAs:
"""Simulate TOAs from a list of MJDs
Parameters
Expand Down Expand Up @@ -464,13 +496,13 @@ def make_fake_toas_fromMJDs(


def make_fake_toas_fromtim(
timfile,
model,
add_noise=False,
add_correlated_noise=False,
name="fake",
subtract_mean=True,
):
timfile: str | List[str] | pathlib.Path,
model: pint.models.timing_model.TimingModel,
add_noise: bool = False,
add_correlated_noise: bool = False,
name: str = "fake",
subtract_mean: bool = True,
) -> pint.toa.TOAs:
"""Simulate fake TOAs with the same times as an input tim file
Parameters
Expand Down Expand Up @@ -526,8 +558,13 @@ def make_fake_toas_fromtim(


def calculate_random_models(
fitter, toas, Nmodels=100, keep_models=True, return_time=False, params="all"
):
fitter: pint.fitter.Fitter,
toas: pint.toa.TOAs,
Nmodels: int = 100,
keep_models: bool = True,
return_time: bool = False,
params: str = "all",
) -> (np.ndarray, Optional[list]):
"""
Calculates random models based on the covariance matrix of the `fitter` object.
Expand Down Expand Up @@ -650,7 +687,13 @@ def calculate_random_models(
return (dphase, random_models) if keep_models else dphase


def _get_freqs_and_times(start, end, ntoas, freqs, multi_freqs_in_epoch=True):
def _get_freqs_and_times(
start: float | u.Quantity | time.Time,
end: float | u.Quantity | time.Time,
ntoas: int,
freqs: u.Quantity,
multi_freqs_in_epoch: bool = True,
) -> (float | u.Quantity | time.Time, np.ndarray):
freqs = np.atleast_1d(freqs)
assert (
len(freqs.shape) == 1 and len(freqs) <= ntoas
Expand Down

0 comments on commit 34f0c08

Please sign in to comment.