From 452127227486b1cb2d87e945f9f3c582c7a29845 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sat, 10 Feb 2024 17:16:07 +0000 Subject: [PATCH 01/28] Use cached_property and types --- src/pint/observatory/__init__.py | 122 +++++++++++++++++++------------ src/pint/observatory/topo_obs.py | 78 ++++++++++---------- 2 files changed, 116 insertions(+), 84 deletions(-) diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py index 895d46411..2935613fe 100644 --- a/src/pint/observatory/__init__.py +++ b/src/pint/observatory/__init__.py @@ -21,14 +21,17 @@ necessary. """ -from copy import deepcopy import os import textwrap from collections import defaultdict +from collections.abc import Callable +from copy import deepcopy from io import StringIO from pathlib import Path +from typing import Optional, Union import astropy.coordinates +import astropy.time import astropy.units as u import numpy as np from astropy.coordinates import EarthLocation @@ -97,7 +100,7 @@ def _load_gps_clock(): ) -def _load_bipm_clock(bipm_version): +def _load_bipm_clock(bipm_version: str): bipm_version = bipm_version.lower() if bipm_version not in _bipm_clock_versions: try: @@ -136,34 +139,40 @@ class Observatory: position. """ + fullname: str + aliases: list[str] + include_gps: bool + include_bipm: bool + bipm_version: str + # This is a dict containing all defined Observatory instances, # keyed on standard observatory name. - _registry = {} + _registry: dict[str, "Observatory"] = {} # This is a dict mapping any defined aliases to the corresponding # standard name. - _alias_map = {} + _alias_map: dict[str, str] = {} def __init__( self, - name, - fullname=None, - aliases=None, - include_gps=True, - include_bipm=True, - bipm_version=bipm_default, - overwrite=False, + name: str, + fullname: Optional[str] = None, + aliases: Optional[list[str]] = None, + include_gps: bool = True, + include_bipm: bool = True, + bipm_version: str = bipm_default, + overwrite: bool = False, ): - self._name = name.lower() - self._aliases = ( + self._name: str = name.lower() + self._aliases: list[str] = ( list(set(map(str.lower, aliases))) if aliases is not None else [] ) if aliases is not None: Observatory._add_aliases(self, aliases) - self.fullname = fullname if fullname is not None else name - self.include_gps = include_gps - self.include_bipm = include_bipm - self.bipm_version = bipm_version + self.fullname: str = fullname if fullname is not None else name + self.include_gps: bool = include_gps + self.include_bipm: bool = include_bipm + self.bipm_version: str = bipm_version if name.lower() in Observatory._registry: if not overwrite: @@ -175,16 +184,18 @@ def __init__( Observatory._register(self, name) @classmethod - def _register(cls, obs, name): - """Add an observatory to the registry using the specified name - (which will be converted to lower case). If an existing observatory + def _register(cls, obs: "Observatory", name: str): + """Add an observatory to the registry using the specified name (which will be converted to lower case). + + If an existing observatory of the same name exists, it will be replaced with the new one. The Observatory instance's name attribute will be updated for - consistency.""" + consistency. + """ cls._registry[name.lower()] = obs @classmethod - def _add_aliases(cls, obs, aliases): + def _add_aliases(cls, obs: "Observatory", aliases: list[str]): """Add aliases for the specified Observatory. Aliases should be given as a list. If any of the new aliases are already in use, they will be replaced. Aliases are not checked against the @@ -196,14 +207,17 @@ def _add_aliases(cls, obs, aliases): cls._alias_map[a.lower()] = obs.name @staticmethod - def gps_correction(t, limits="warn"): + def gps_correction(t: astropy.time.Time, limits: str = "warn"): """Compute the GPS clock corrections for times t.""" log.info("Applying GPS to UTC clock correction (~few nanoseconds)") _load_gps_clock() + assert _gps_clock is not None return _gps_clock.evaluate(t, limits=limits) @staticmethod - def bipm_correction(t, bipm_version=bipm_default, limits="warn"): + def bipm_correction( + t: astropy.time.Time, bipm_version: str = bipm_default, limits: str = "warn" + ): """Compute the GPS clock corrections for times t.""" log.info(f"Applying TT(TAI) to TT({bipm_version}) clock correction (~27 us)") tt2tai = 32.184 * 1e6 * u.us @@ -214,7 +228,7 @@ def bipm_correction(t, bipm_version=bipm_default, limits="warn"): ) @classmethod - def clear_registry(cls): + def clear_registry(cls) -> None: """Clear registry for ground-based observatories.""" cls._registry = {} cls._alias_map = {} @@ -229,7 +243,7 @@ def names(cls): return cls._registry.keys() @classmethod - def names_and_aliases(cls): + def names_and_aliases(cls) -> dict[str, list[str]]: """List all observatories and their aliases""" import pint.observatory.topo_obs # noqa import pint.observatory.special_locations # noqa @@ -241,15 +255,15 @@ def names_and_aliases(cls): # setter methods that update the registries appropriately. @property - def name(self): + def name(self) -> str: return self._name @property - def aliases(self): + def aliases(self) -> list[str]: return self._aliases @classmethod - def get(cls, name): + def get(cls, name: str) -> "Observatory": """Returns the Observatory instance for the specified name/alias. If the name has not been defined, an error will be raised. Aside @@ -303,9 +317,12 @@ def get(cls, name): # Any which raise NotImplementedError below must be implemented in # derived classes. - def earth_location_itrf(self, time=None): - """Returns observatory geocentric position as an astropy - EarthLocation object. For observatories where this is not + def earth_location_itrf( + self, time: Optional[astropy.time.Time] = None + ) -> Union[None, np.ndarray]: + """Returns observatory geocentric position as an astropy EarthLocation object. + + For observatories where this is not relevant, None can be returned. The location is in the International Terrestrial Reference Frame (ITRF). @@ -319,8 +336,9 @@ def earth_location_itrf(self, time=None): """ return None - def get_gcrs(self, t, ephem=None): - """Return position vector of observatory in GCRS + def get_gcrs(self, t: astropy.time.Time, ephem=None): + """Return position vector of observatory in GCRS. + t is an astropy.Time or array of astropy.Time objects ephem is a link to an ephemeris file. Needed for SSB observatory Returns a 3-vector of Quantities representing the position @@ -329,14 +347,17 @@ def get_gcrs(self, t, ephem=None): raise NotImplementedError @property - def timescale(self): - """Returns the timescale that TOAs from this observatory will be in, - once any clock corrections have been applied. This should be a + def timescale(self) -> str: + """Returns the timescale that TOAs from this observatory will be in, once any clock corrections have been applied. + + This should be a string suitable to be passed directly to the scale argument of astropy.time.Time().""" raise NotImplementedError - def clock_corrections(self, t, limits="warn"): + def clock_corrections( + self, t: astropy.time.Time, limits: str = "warn" + ) -> u.Quantity: """Compute clock corrections for a Time array. Given an array-valued Time, return the clock corrections @@ -356,7 +377,7 @@ def clock_corrections(self, t, limits="warn"): return corr - def last_clock_correction_mjd(self): + def last_clock_correction_mjd(self) -> float: """Return the MJD of the last available clock correction. Returns ``np.inf`` if no clock corrections are relevant. @@ -365,6 +386,7 @@ def last_clock_correction_mjd(self): if self.include_gps: _load_gps_clock() + assert _gps_clock is not None t = min(t, _gps_clock.last_correction_mjd()) if self.include_bipm: _load_bipm_clock(self.bipm_version) @@ -374,7 +396,13 @@ def last_clock_correction_mjd(self): ) return t - def get_TDBs(self, t, method="default", ephem=None, options=None): + def get_TDBs( + self, + t: astropy.time.Time, + method: Union[str, Callable] = "default", + ephem: Optional[str] = None, + options: Optional[dict] = None, + ): """This is a high level function for converting TOAs to TDB time scale. Different method can be applied to obtain the result. Current supported @@ -409,13 +437,13 @@ def get_TDBs(self, t, method="default", ephem=None, options=None): t = Time([t]) if t.scale == "tdb": return t - # Check the method. This pattern is from numpy minimize - meth = "_custom" if callable(method) else method.lower() if options is None: options = {} - if meth == "_custom": + if callable(method): options = dict(options) return method(t, **options) + else: + meth = method.lower() if meth == "default": return self._get_TDB_default(t, ephem) elif meth == "ephemeris": @@ -428,17 +456,17 @@ def get_TDBs(self, t, method="default", ephem=None, options=None): else: raise ValueError(f"Unknown method '{method}'.") - def _get_TDB_default(self, t, ephem): + def _get_TDB_default(self, t: astropy.time.Time, ephem): return t.tdb - def _get_TDB_ephem(self, t, ephem): + def _get_TDB_ephem(self, t: astropy.time.Time, ephem): """Read the ephem TDB-TT column. This column is provided by DE4XXt version of ephemeris. """ raise NotImplementedError - def posvel(self, t, ephem, group=None): + def posvel(self, t: astropy.time.Time, ephem, group=None): """Return observatory position and velocity for the given times. Position is relative to solar system barycenter; times are @@ -451,7 +479,7 @@ def posvel(self, t, ephem, group=None): def get_observatory( - name, include_gps=None, include_bipm=None, bipm_version=bipm_default + name: str, include_gps=None, include_bipm=None, bipm_version: str = bipm_default ): """Convenience function to get observatory object with options. diff --git a/src/pint/observatory/topo_obs.py b/src/pint/observatory/topo_obs.py index 1d7fe8868..33fc8b0e2 100644 --- a/src/pint/observatory/topo_obs.py +++ b/src/pint/observatory/topo_obs.py @@ -17,12 +17,15 @@ -------- :mod:`pint.observatory.special_locations` """ +import copy import json import os +from functools import cached_property from pathlib import Path -import copy +from typing import Optional import astropy.constants as c +import astropy.time import astropy.units as u import numpy as np from astropy.coordinates import EarthLocation @@ -36,9 +39,9 @@ NoClockCorrections, Observatory, bipm_default, + earth_location_distance, find_clock_file, get_observatory, - earth_location_distance, ) from pint.pulsar_mjd import Time from pint.solar_system_ephemerides import get_tdb_tt_ephem_geocenter, objPosVel_wrt_SSB @@ -149,36 +152,36 @@ class TopoObs(Observatory): def __init__( self, - name, + name: str, *, - fullname=None, - tempo_code=None, - itoa_code=None, - aliases=None, + fullname: Optional[str] = None, + tempo_code: Optional[str] = None, + itoa_code: Optional[str] = None, + aliases: Optional[list[str]] = None, location=None, itrf_xyz=None, - lat=None, - lon=None, + lat: Optional[float] = None, + lon: Optional[float] = None, height=None, - clock_file="", - clock_fmt="tempo", + clock_file: str = "", + clock_fmt: str = "tempo", clock_dir=None, - include_gps=True, - include_bipm=True, - bipm_version=bipm_default, + include_gps: bool = True, + include_bipm: bool = True, + bipm_version: str = bipm_default, origin=None, - overwrite=False, - bogus_last_correction=False, + overwrite: bool = False, + bogus_last_correction: bool = False, ): input_values = [lat is not None, lon is not None, height is not None] - if sum(input_values) > 0 and sum(input_values) < 3: + if any(input_values) and not all(input_values): raise ValueError("All of lat, lon, height are required for observatory") input_values = [ location is not None, itrf_xyz is not None, (lat is not None and lon is not None and height is not None), ] - if sum(input_values) == 0: + if not any(input_values): raise ValueError( f"EarthLocation, ITRF coordinates, or lat/lon/height are required for observatory '{name}'" ) @@ -209,11 +212,12 @@ def __init__( # Save clock file info, the data will be read only if clock # corrections for this site are requested. - self.clock_files = [clock_file] if isinstance(clock_file, str) else clock_file - self.clock_files = [c for c in self.clock_files if c != ""] - self.clock_fmt = clock_fmt + clock_files: list[str] = ( + [clock_file] if isinstance(clock_file, str) else clock_file + ) + self.clock_files: list[str] = [c for c in clock_files if c != ""] + self.clock_fmt: str = clock_fmt self.clock_dir = clock_dir - self._clock = None # The ClockFile objects, will be read on demand # If using TEMPO time.dat we need to know the 1-char tempo-style # observatory code. @@ -315,10 +319,9 @@ def separation(self, other, method="cartesian"): def earth_location_itrf(self, time=None): return self.location - def _load_clock_corrections(self): - if self._clock is not None: - return - self._clock = [] + @cached_property + def _clock(self) -> list: + clock = [] for cf in self.clock_files: if cf == "": continue @@ -326,7 +329,7 @@ def _load_clock_corrections(self): if isinstance(cf, dict): kwargs.update(cf) cf = kwargs.pop("name") - self._clock.append( + clock.append( find_clock_file( cf, format=self.clock_fmt, @@ -334,8 +337,11 @@ def _load_clock_corrections(self): **kwargs, ) ) + return clock - def clock_corrections(self, t, limits="warn"): + def clock_corrections( + self, t: astropy.time.Time, limits: str = "warn" + ) -> u.Quantity: """Compute the total clock corrections, Parameters @@ -344,17 +350,16 @@ def clock_corrections(self, t, limits="warn"): The time when the clock correcions are applied. """ - corr = super().clock_corrections(t, limits=limits) - # Read clock file if necessary - self._load_clock_corrections() + corr: u.Quantity = super().clock_corrections(t, limits=limits) if self._clock: log.info( f"Applying observatory clock corrections for observatory='{self.name}'." ) for clock in self._clock: corr += clock.evaluate(t, limits=limits) - elif self.clock_files: + # clock_files is not empty, but no clock corrections found + # FIXME: what if only some were found? msg = f"No clock corrections found for observatory {self.name} taken from file {self.clock_files}" if limits == "warn": log.warning(msg) @@ -365,19 +370,18 @@ def clock_corrections(self, t, limits="warn"): log.info(f"Observatory {self.name} requires no clock corrections.") return corr - def last_clock_correction_mjd(self): + def last_clock_correction_mjd(self) -> float: """Return the MJD of the last clock correction. Combines constraints based on Earth orientation parameters and on the available clock corrections specific to the telescope. """ t = super().last_clock_correction_mjd() - self._load_clock_corrections() for clock in self._clock: t = min(t, clock.last_correction_mjd()) return t - def _get_TDB_ephem(self, t, ephem): + def _get_TDB_ephem(self, t: astropy.time.Time, ephem): """Read the ephem TDB-TT column. This column is provided by DE4XXt version of ephemeris. This function is only @@ -406,7 +410,7 @@ def _get_TDB_ephem(self, t, ephem): location=self.earth_location_itrf(), ) - def get_gcrs(self, t, ephem=None): + def get_gcrs(self, t: astropy.time.Time, ephem=None): """Return position vector of TopoObs in GCRS Parameters @@ -423,7 +427,7 @@ def get_gcrs(self, t, ephem=None): ) return obs_geocenter_pv.pos - def posvel(self, t, ephem, group=None): + def posvel(self, t: astropy.time.Time, ephem, group=None): if t.isscalar: t = Time([t]) earth_pv = objPosVel_wrt_SSB("earth", t, ephem) From 32b28241ca048f9b3f53dd2707d439d064d78240 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sat, 10 Feb 2024 17:31:57 +0000 Subject: [PATCH 02/28] Fix python 3.8-incompatible syntax --- src/pint/observatory/__init__.py | 18 ++++++++-------- src/pint/observatory/topo_obs.py | 35 +++++++++++++++++++++++++++----- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py index 2935613fe..65f4bcf5e 100644 --- a/src/pint/observatory/__init__.py +++ b/src/pint/observatory/__init__.py @@ -28,7 +28,7 @@ from copy import deepcopy from io import StringIO from pathlib import Path -from typing import Optional, Union +from typing import Optional, Union, List, Dict import astropy.coordinates import astropy.time @@ -140,31 +140,31 @@ class Observatory: """ fullname: str - aliases: list[str] + aliases: List[str] include_gps: bool include_bipm: bool bipm_version: str # This is a dict containing all defined Observatory instances, # keyed on standard observatory name. - _registry: dict[str, "Observatory"] = {} + _registry: Dict[str, "Observatory"] = {} # This is a dict mapping any defined aliases to the corresponding # standard name. - _alias_map: dict[str, str] = {} + _alias_map: Dict[str, str] = {} def __init__( self, name: str, fullname: Optional[str] = None, - aliases: Optional[list[str]] = None, + aliases: Optional[List[str]] = None, include_gps: bool = True, include_bipm: bool = True, bipm_version: str = bipm_default, overwrite: bool = False, ): self._name: str = name.lower() - self._aliases: list[str] = ( + self._aliases: List[str] = ( list(set(map(str.lower, aliases))) if aliases is not None else [] ) if aliases is not None: @@ -195,7 +195,7 @@ def _register(cls, obs: "Observatory", name: str): cls._registry[name.lower()] = obs @classmethod - def _add_aliases(cls, obs: "Observatory", aliases: list[str]): + def _add_aliases(cls, obs: "Observatory", aliases: List[str]): """Add aliases for the specified Observatory. Aliases should be given as a list. If any of the new aliases are already in use, they will be replaced. Aliases are not checked against the @@ -243,7 +243,7 @@ def names(cls): return cls._registry.keys() @classmethod - def names_and_aliases(cls) -> dict[str, list[str]]: + def names_and_aliases(cls) -> Dict[str, List[str]]: """List all observatories and their aliases""" import pint.observatory.topo_obs # noqa import pint.observatory.special_locations # noqa @@ -259,7 +259,7 @@ def name(self) -> str: return self._name @property - def aliases(self) -> list[str]: + def aliases(self) -> List[str]: return self._aliases @classmethod diff --git a/src/pint/observatory/topo_obs.py b/src/pint/observatory/topo_obs.py index 33fc8b0e2..4afc341e9 100644 --- a/src/pint/observatory/topo_obs.py +++ b/src/pint/observatory/topo_obs.py @@ -22,7 +22,7 @@ import os from functools import cached_property from pathlib import Path -from typing import Optional +from typing import Optional, Union, List import astropy.constants as c import astropy.time @@ -150,6 +150,31 @@ class TopoObs(Observatory): """ + tempo_code: Optional[str] + """One-character TEMPO code.""" + itoa_code: Optional[str] + """Two-character ITOA code.""" + location: EarthLocation + """Location of the observatory.""" + clock_files: List[str] + """List of files to read for clock corrections. If empty, no clock corrections are applied.""" + clock_fmt: str + """Format of the clock files. + + See :class:`pint.observatory.clock_file.ClockFile` for allowed values. + """ + bogus_last_correction: bool + """Clock correction files include a bogus last correction. + + This is common with TEMPO/TEMPO2 clock files since neither program does + a good job with times past the end ot the table. It makes detecting values + past the end of real calibration difficult if it's not marked as bogus. + """ + clock_dir: Optional[Union[str, Path]] + """Where to look for the clock files.""" + origin: Optional[str] + """Documentation of the origin/author/date for the information.""" + def __init__( self, name: str, @@ -157,7 +182,7 @@ def __init__( fullname: Optional[str] = None, tempo_code: Optional[str] = None, itoa_code: Optional[str] = None, - aliases: Optional[list[str]] = None, + aliases: Optional[List[str]] = None, location=None, itrf_xyz=None, lat: Optional[float] = None, @@ -169,7 +194,7 @@ def __init__( include_gps: bool = True, include_bipm: bool = True, bipm_version: str = bipm_default, - origin=None, + origin: Optional[str] = None, overwrite: bool = False, bogus_last_correction: bool = False, ): @@ -212,10 +237,10 @@ def __init__( # Save clock file info, the data will be read only if clock # corrections for this site are requested. - clock_files: list[str] = ( + clock_files: List[str] = ( [clock_file] if isinstance(clock_file, str) else clock_file ) - self.clock_files: list[str] = [c for c in clock_files if c != ""] + self.clock_files: List[str] = [c for c in clock_files if c != ""] self.clock_fmt: str = clock_fmt self.clock_dir = clock_dir From 10a7248aa4dd78bc4c4a7c9dc4bac6ca9847ee8f Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sat, 10 Feb 2024 17:46:17 +0000 Subject: [PATCH 03/28] Fill out typing more --- src/pint/observatory/__init__.py | 27 +++++++++++++-------------- src/pint/observatory/topo_obs.py | 32 +++++++++++++++----------------- 2 files changed, 28 insertions(+), 31 deletions(-) diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py index 65f4bcf5e..8803b655b 100644 --- a/src/pint/observatory/__init__.py +++ b/src/pint/observatory/__init__.py @@ -39,7 +39,7 @@ from pint.config import runtimefile from pint.pulsar_mjd import Time -from pint.utils import interesting_lines +from pint.utils import interesting_lines, PosVel # Include any files that define observatories here. This will start # with the standard distribution files, then will read any system- or @@ -90,7 +90,7 @@ class ClockCorrectionOutOfRange(ClockCorrectionError): _bipm_clock_versions = {} -def _load_gps_clock(): +def _load_gps_clock() -> None: global _gps_clock if _gps_clock is None: log.info("Loading global GPS clock file") @@ -100,7 +100,7 @@ def _load_gps_clock(): ) -def _load_bipm_clock(bipm_version: str): +def _load_bipm_clock(bipm_version: str) -> None: bipm_version = bipm_version.lower() if bipm_version not in _bipm_clock_versions: try: @@ -140,7 +140,6 @@ class Observatory: """ fullname: str - aliases: List[str] include_gps: bool include_bipm: bool bipm_version: str @@ -184,7 +183,7 @@ def __init__( Observatory._register(self, name) @classmethod - def _register(cls, obs: "Observatory", name: str): + def _register(cls, obs: "Observatory", name: str) -> None: """Add an observatory to the registry using the specified name (which will be converted to lower case). If an existing observatory @@ -195,7 +194,7 @@ def _register(cls, obs: "Observatory", name: str): cls._registry[name.lower()] = obs @classmethod - def _add_aliases(cls, obs: "Observatory", aliases: List[str]): + def _add_aliases(cls, obs: "Observatory", aliases: List[str]) -> None: """Add aliases for the specified Observatory. Aliases should be given as a list. If any of the new aliases are already in use, they will be replaced. Aliases are not checked against the @@ -207,7 +206,7 @@ def _add_aliases(cls, obs: "Observatory", aliases: List[str]): cls._alias_map[a.lower()] = obs.name @staticmethod - def gps_correction(t: astropy.time.Time, limits: str = "warn"): + def gps_correction(t: astropy.time.Time, limits: str = "warn") -> u.Quantity: """Compute the GPS clock corrections for times t.""" log.info("Applying GPS to UTC clock correction (~few nanoseconds)") _load_gps_clock() @@ -217,7 +216,7 @@ def gps_correction(t: astropy.time.Time, limits: str = "warn"): @staticmethod def bipm_correction( t: astropy.time.Time, bipm_version: str = bipm_default, limits: str = "warn" - ): + ) -> u.Quantity: """Compute the GPS clock corrections for times t.""" log.info(f"Applying TT(TAI) to TT({bipm_version}) clock correction (~27 us)") tt2tai = 32.184 * 1e6 * u.us @@ -466,7 +465,7 @@ def _get_TDB_ephem(self, t: astropy.time.Time, ephem): """ raise NotImplementedError - def posvel(self, t: astropy.time.Time, ephem, group=None): + def posvel(self, t: astropy.time.Time, ephem, group=None) -> PosVel: """Return observatory position and velocity for the given times. Position is relative to solar system barycenter; times are @@ -519,14 +518,14 @@ def get_observatory( return Observatory.get(name) -def earth_location_distance(loc1, loc2): +def earth_location_distance(loc1: EarthLocation, loc2: EarthLocation) -> u.Quantity: """Compute the distance between two EarthLocations.""" return ( sum((u.Quantity(loc1.to_geocentric()) - u.Quantity(loc2.to_geocentric())) ** 2) ) ** 0.5 -def compare_t2_observatories_dat(t2dir=None): +def compare_t2_observatories_dat(t2dir: Optional[str] = None) -> Dict[str, List[Dict]]: """Read a tempo2 observatories.dat file and compare with PINT Produces a report including lines that can be added to PINT's @@ -617,7 +616,7 @@ def compare_t2_observatories_dat(t2dir=None): return report -def compare_tempo_obsys_dat(tempodir=None): +def compare_tempo_obsys_dat(tempodir: Optional[str] = None) -> Dict[str, List[Dict]]: """Read a tempo obsys.dat file and compare with PINT. Produces a report including lines that can be added to PINT's @@ -657,8 +656,8 @@ def compare_tempo_obsys_dat(tempodir=None): y = float(line_io.read(15)) z = float(line_io.read(15)) line_io.read(2) - icoord = line_io.read(1).strip() - icoord = int(icoord) if icoord else 0 + icoord_str = line_io.read(1).strip() + icoord = int(icoord_str) if icoord_str else 0 line_io.read(2) obsnam = line_io.read(20).strip().lower() tempo_code = line_io.read(1) diff --git a/src/pint/observatory/topo_obs.py b/src/pint/observatory/topo_obs.py index 4afc341e9..0b810aa12 100644 --- a/src/pint/observatory/topo_obs.py +++ b/src/pint/observatory/topo_obs.py @@ -45,7 +45,7 @@ ) from pint.pulsar_mjd import Time from pint.solar_system_ephemerides import get_tdb_tt_ephem_geocenter, objPosVel_wrt_SSB -from pint.utils import has_astropy_unit, open_or_use +from pint.utils import has_astropy_unit, open_or_use, PosVel # environment variables that can override clock location and observatory location pint_obs_env_var = "PINT_OBS_OVERRIDE" @@ -277,7 +277,7 @@ def __init__( overwrite=overwrite, ) - def __repr__(self): + def __repr__(self) -> str: aliases = [f"'{x}'" for x in self.aliases] origin = ( f"{self.fullname}\n{self.origin}" @@ -309,7 +309,7 @@ def get_json(self): """Return as a JSON string""" return json.dumps(self.get_dict()) - def separation(self, other, method="cartesian"): + def separation(self, other: "TopoObs", method: str = "cartesian"): """Return separation between two TopoObs objects Parameters @@ -341,7 +341,7 @@ def separation(self, other, method="cartesian"): ) return (c.R_earth * dsigma).to(u.m, equivalencies=u.dimensionless_angles()) - def earth_location_itrf(self, time=None): + def earth_location_itrf(self, time=None) -> EarthLocation: return self.location @cached_property @@ -364,9 +364,7 @@ def _clock(self) -> list: ) return clock - def clock_corrections( - self, t: astropy.time.Time, limits: str = "warn" - ) -> u.Quantity: + def clock_corrections(self, t: Time, limits: str = "warn") -> u.Quantity: """Compute the total clock corrections, Parameters @@ -406,7 +404,7 @@ def last_clock_correction_mjd(self) -> float: t = min(t, clock.last_correction_mjd()) return t - def _get_TDB_ephem(self, t: astropy.time.Time, ephem): + def _get_TDB_ephem(self, t: Time, ephem) -> Time: """Read the ephem TDB-TT column. This column is provided by DE4XXt version of ephemeris. This function is only @@ -418,8 +416,8 @@ def _get_TDB_ephem(self, t: astropy.time.Time, ephem): # Topocenter to Geocenter # Since earth velocity is not going to change a lot in 3ms. The # differences between TT and TDB can be ignored. - earth_pv = objPosVel_wrt_SSB("earth", t.tdb, ephem) - obs_geocenter_pv = gcrs_posvel_from_itrf( + earth_pv: PosVel = objPosVel_wrt_SSB("earth", t.tdb, ephem) + obs_geocenter_pv: PosVel = gcrs_posvel_from_itrf( self.earth_location_itrf(), t, obsname=self.name ) # NOTE @@ -447,22 +445,22 @@ def get_gcrs(self, t: astropy.time.Time, ephem=None): np.array a 3-vector of Quantities representing the position in GCRS coordinates. """ - obs_geocenter_pv = gcrs_posvel_from_itrf( + obs_geocenter_pv: PosVel = gcrs_posvel_from_itrf( self.earth_location_itrf(), t, obsname=self.name ) return obs_geocenter_pv.pos - def posvel(self, t: astropy.time.Time, ephem, group=None): + def posvel(self, t: astropy.time.Time, ephem, group=None) -> PosVel: if t.isscalar: t = Time([t]) - earth_pv = objPosVel_wrt_SSB("earth", t, ephem) - obs_geocenter_pv = gcrs_posvel_from_itrf( + earth_pv: PosVel = objPosVel_wrt_SSB("earth", t, ephem) + obs_geocenter_pv: PosVel = gcrs_posvel_from_itrf( self.earth_location_itrf(), t, obsname=self.name ) return obs_geocenter_pv + earth_pv -def export_all_clock_files(directory): +def export_all_clock_files(directory: Union[str, Path]) -> None: """Export all clock files PINT is using. This will export all the clock files PINT is using - every clock file used @@ -494,7 +492,7 @@ def export_all_clock_files(directory): clock.export(directory / Path(clock.filename).name) -def load_observatories(filename=observatories_json, overwrite=False): +def load_observatories(filename=observatories_json, overwrite: bool = False) -> None: """Load observatory definitions from JSON and create :class:`pint.observatory.topo_obs.TopoObs` objects, registering them Set `overwrite` to ``True`` if you want to re-read a file with updated definitions. @@ -528,7 +526,7 @@ def load_observatories(filename=observatories_json, overwrite=False): TopoObs(name=obsname, **obsdict) -def load_observatories_from_usual_locations(clear=False): +def load_observatories_from_usual_locations(clear: bool = False) -> None: """Load observatories from the default JSON file as well as ``$PINT_OBS_OVERRIDE``, optionally clearing the registry Running with ``clear=True`` will return PINT to the state it is on import. From f39772687f24604cfdee26dd1eb95b28f8502586 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sat, 10 Feb 2024 18:16:46 +0000 Subject: [PATCH 04/28] Improve attribute documentation --- src/pint/observatory/__init__.py | 13 +++++++++++++ src/pint/observatory/topo_obs.py | 4 ++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py index 8803b655b..8a5663b7d 100644 --- a/src/pint/observatory/__init__.py +++ b/src/pint/observatory/__init__.py @@ -140,9 +140,13 @@ class Observatory: """ fullname: str + """Full human-readable name of the observatory.""" include_gps: bool + """Whether to include GPS clock corrections.""" include_bipm: bool + """Whether to include BIPM clock corrections.""" bipm_version: str + """Version of the BIPM clock file to use.""" # This is a dict containing all defined Observatory instances, # keyed on standard observatory name. @@ -255,10 +259,19 @@ def names_and_aliases(cls) -> Dict[str, List[str]]: @property def name(self) -> str: + """Short name of the observatory. + + This is the name used in TOA files and in the observatory registry. + """ return self._name @property def aliases(self) -> List[str]: + """List of aliases for the observatory. + + These are short names also used to specify this observatory. + Includes ITOA and TEMPO codes, and any other common names. + """ return self._aliases @classmethod diff --git a/src/pint/observatory/topo_obs.py b/src/pint/observatory/topo_obs.py index 0b810aa12..4a247e701 100644 --- a/src/pint/observatory/topo_obs.py +++ b/src/pint/observatory/topo_obs.py @@ -306,11 +306,11 @@ def get_dict(self): return {self.name: output} def get_json(self): - """Return as a JSON string""" + """Return as a JSON string.""" return json.dumps(self.get_dict()) def separation(self, other: "TopoObs", method: str = "cartesian"): - """Return separation between two TopoObs objects + """Return separation between two TopoObs objects. Parameters ---------- From 6da627a1ecae2e2a82410da87dd647a66b91e34e Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Wed, 21 Feb 2024 18:50:07 +0000 Subject: [PATCH 05/28] Additional type annotations --- src/pint/observatory/__init__.py | 33 +++++++++++++++++--------------- src/pint/observatory/topo_obs.py | 18 ++++++++--------- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py index 8a5663b7d..ad889ded1 100644 --- a/src/pint/observatory/__init__.py +++ b/src/pint/observatory/__init__.py @@ -28,7 +28,7 @@ from copy import deepcopy from io import StringIO from pathlib import Path -from typing import Optional, Union, List, Dict +from typing import Optional, Union, List, Dict, Literal import astropy.coordinates import astropy.time @@ -348,7 +348,7 @@ def earth_location_itrf( """ return None - def get_gcrs(self, t: astropy.time.Time, ephem=None): + def get_gcrs(self, t: astropy.time.Time, ephem: Optional[str] = None): """Return position vector of observatory in GCRS. t is an astropy.Time or array of astropy.Time objects @@ -468,17 +468,17 @@ def get_TDBs( else: raise ValueError(f"Unknown method '{method}'.") - def _get_TDB_default(self, t: astropy.time.Time, ephem): + def _get_TDB_default(self, t: astropy.time.Time, ephem: Optional[str]): return t.tdb - def _get_TDB_ephem(self, t: astropy.time.Time, ephem): + def _get_TDB_ephem(self, t: astropy.time.Time, ephem: Optional[str]): """Read the ephem TDB-TT column. This column is provided by DE4XXt version of ephemeris. """ raise NotImplementedError - def posvel(self, t: astropy.time.Time, ephem, group=None) -> PosVel: + def posvel(self, t: astropy.time.Time, ephem: Optional[str], group=None) -> PosVel: """Return observatory position and velocity for the given times. Position is relative to solar system barycenter; times are @@ -491,7 +491,10 @@ def posvel(self, t: astropy.time.Time, ephem, group=None) -> PosVel: def get_observatory( - name: str, include_gps=None, include_bipm=None, bipm_version: str = bipm_default + name: str, + include_gps: Optional[bool] = None, + include_bipm: Optional[bool] = None, + bipm_version: str = bipm_default, ): """Convenience function to get observatory object with options. @@ -753,7 +756,7 @@ def convert_angle(x): return report -def list_last_correction_mjds(): +def list_last_correction_mjds() -> None: """Print out a list of the last MJD each clock correction is good for. Each observatory lists the clock files it uses and their last dates, @@ -784,7 +787,7 @@ def list_last_correction_mjds(): print(f" {c.friendly_name:<20} MISSING") -def update_clock_files(bipm_versions=None): +def update_clock_files(bipm_versions: Optional[list[str]] = None) -> None: """Obtain an up-to-date version of all clock files. This up-to-date version will be stored in the Astropy cache; @@ -826,13 +829,13 @@ def update_clock_files(bipm_versions=None): # Both topo_obs and special_locations need this def find_clock_file( - name, - format, - bogus_last_correction=False, - url_base=None, - clock_dir=None, - valid_beyond_ends=False, -): + name: str, + format: Literal["tempo", "tempo2"], + bogus_last_correction: bool = False, + url_base: Optional[str] = None, + clock_dir: Union[str, Path, None] = None, + valid_beyond_ends: bool = False, +) -> "ClockFile": """Locate and return a ClockFile in one of several places. PINT looks for clock files in three places, in order: diff --git a/src/pint/observatory/topo_obs.py b/src/pint/observatory/topo_obs.py index 4a247e701..53ad6362e 100644 --- a/src/pint/observatory/topo_obs.py +++ b/src/pint/observatory/topo_obs.py @@ -22,7 +22,7 @@ import os from functools import cached_property from pathlib import Path -from typing import Optional, Union, List +from typing import Optional, Union, List, Any import astropy.constants as c import astropy.time @@ -183,14 +183,14 @@ def __init__( tempo_code: Optional[str] = None, itoa_code: Optional[str] = None, aliases: Optional[List[str]] = None, - location=None, + location: Optional[EarthLocation] = None, itrf_xyz=None, lat: Optional[float] = None, lon: Optional[float] = None, height=None, clock_file: str = "", clock_fmt: str = "tempo", - clock_dir=None, + clock_dir: Union[str, Path, None] = None, include_gps: bool = True, include_bipm: bool = True, bipm_version: str = bipm_default, @@ -287,10 +287,10 @@ def __repr__(self) -> str: return f"TopoObs('{self.name}' ({','.join(aliases)}) at [{self.location.x}, {self.location.y} {self.location.z}]:\n{origin})" @property - def timescale(self): + def timescale(self) -> str: return "utc" - def get_dict(self): + def get_dict(self) -> dict[str, dict[str, Any]]: """Return as a dict with limited/changed info""" # start with the default __dict__ # copy some attributes to rename them and remove those that aren't needed for initialization @@ -305,11 +305,11 @@ def get_dict(self): output["itrf_xyz"] = [x.to_value(u.m) for x in self.location.geocentric] return {self.name: output} - def get_json(self): + def get_json(self) -> str: """Return as a JSON string.""" return json.dumps(self.get_dict()) - def separation(self, other: "TopoObs", method: str = "cartesian"): + def separation(self, other: "TopoObs", method: str = "cartesian") -> u.Quantity: """Return separation between two TopoObs objects. Parameters @@ -433,7 +433,7 @@ def _get_TDB_ephem(self, t: Time, ephem) -> Time: location=self.earth_location_itrf(), ) - def get_gcrs(self, t: astropy.time.Time, ephem=None): + def get_gcrs(self, t: astropy.time.Time, ephem: Optional[str] = None): """Return position vector of TopoObs in GCRS Parameters @@ -450,7 +450,7 @@ def get_gcrs(self, t: astropy.time.Time, ephem=None): ) return obs_geocenter_pv.pos - def posvel(self, t: astropy.time.Time, ephem, group=None) -> PosVel: + def posvel(self, t: astropy.time.Time, ephem: Optional[str], group=None) -> PosVel: if t.isscalar: t = Time([t]) earth_pv: PosVel = objPosVel_wrt_SSB("earth", t, ephem) From 23f5d6d6025d6af721328d7e751474f5f745a839 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Wed, 21 Feb 2024 18:54:43 +0000 Subject: [PATCH 06/28] Fix list->List --- src/pint/observatory/__init__.py | 2 +- src/pint/observatory/topo_obs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py index ad889ded1..bccb40d53 100644 --- a/src/pint/observatory/__init__.py +++ b/src/pint/observatory/__init__.py @@ -787,7 +787,7 @@ def list_last_correction_mjds() -> None: print(f" {c.friendly_name:<20} MISSING") -def update_clock_files(bipm_versions: Optional[list[str]] = None) -> None: +def update_clock_files(bipm_versions: Optional[List[str]] = None) -> None: """Obtain an up-to-date version of all clock files. This up-to-date version will be stored in the Astropy cache; diff --git a/src/pint/observatory/topo_obs.py b/src/pint/observatory/topo_obs.py index 53ad6362e..8cbe1ddc0 100644 --- a/src/pint/observatory/topo_obs.py +++ b/src/pint/observatory/topo_obs.py @@ -404,7 +404,7 @@ def last_clock_correction_mjd(self) -> float: t = min(t, clock.last_correction_mjd()) return t - def _get_TDB_ephem(self, t: Time, ephem) -> Time: + def _get_TDB_ephem(self, t: Time, ephem: Optional[str]) -> Time: """Read the ephem TDB-TT column. This column is provided by DE4XXt version of ephemeris. This function is only From 6fda954b49bcf79c9d25c3ffa59ea9dc2b6baf37 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Wed, 21 Feb 2024 18:58:17 +0000 Subject: [PATCH 07/28] dict -> Dict --- src/pint/observatory/topo_obs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pint/observatory/topo_obs.py b/src/pint/observatory/topo_obs.py index 8cbe1ddc0..3e47d9b5b 100644 --- a/src/pint/observatory/topo_obs.py +++ b/src/pint/observatory/topo_obs.py @@ -22,7 +22,7 @@ import os from functools import cached_property from pathlib import Path -from typing import Optional, Union, List, Any +from typing import Optional, Union, List, Any, Dict import astropy.constants as c import astropy.time @@ -290,7 +290,7 @@ def __repr__(self) -> str: def timescale(self) -> str: return "utc" - def get_dict(self) -> dict[str, dict[str, Any]]: + def get_dict(self) -> Dict[str, Dict[str, Any]]: """Return as a dict with limited/changed info""" # start with the default __dict__ # copy some attributes to rename them and remove those that aren't needed for initialization From 4893357e0ea91cd5a8905e9509af0416f364bc02 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sun, 25 Feb 2024 14:27:55 +0000 Subject: [PATCH 08/28] Initial mypy setup passing --- requirements_dev.txt | 4 +++ setup.cfg | 64 ++++++++++++++++++++++++++++++++++++++++++ src/pint/logging.py | 2 +- src/pint/pulsar_mjd.py | 13 ++------- src/pint/py.typed | 0 tox.ini | 10 +++++++ 6 files changed, 82 insertions(+), 11 deletions(-) create mode 100644 src/pint/py.typed diff --git a/requirements_dev.txt b/requirements_dev.txt index a00511840..90348b340 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -41,3 +41,7 @@ loguru gprof2dot py-cpuinfo pytest-xdist +mypy==1.8.0 +GitPython +types-setuptools +types-tqdm diff --git a/setup.cfg b/setup.cfg index ad2570a5f..9e841acc0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -129,3 +129,67 @@ skip_glob = src/pint/extern/* include_trailing_comma = True combine_as_imports = True + +[mypy] +warn_unused_configs = True +files = src/pint + +[mypy-pint.templates.*] +ignore_errors = True + +[mypy-pint.derived_quantities] +ignore_errors = True + +[mypy-pint.observatory.*] +ignore_errors = True + +[mypy-pint.models.stand_alone_psr_binaries.*] +ignore_errors = True + +[mypy-pint.models.timing_model] +ignore_errors = True + +[mypy-pint.models.pulsar_binary] +ignore_errors = True + +[mypy-pint.fitter] +ignore_errors = True + +[mypy-pint.polycos] +ignore_errors = True + +[mypy-pint.output.publish] +ignore_errors = True + +[mypy-pint.gridutils] +ignore_errors = True + +[mypy-pint.scripts.*] +ignore_errors = True + +[mypy-pint.pintk.plk] +ignore_errors = True + +[mypy-astropy.*] +ignore_missing_imports = True + +[mypy-erfa] +ignore_missing_imports = True + +[mypy-emcee] +ignore_missing_imports = True + +[mypy-jplephem] +ignore_missing_imports = True + +[mypy-numdifftools] +ignore_missing_imports = True + +[mypy-pylab] +ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True + +[mypy-uncertainties] +ignore_missing_imports = True \ No newline at end of file diff --git a/src/pint/logging.py b/src/pint/logging.py index 0f4f67021..5187091b3 100644 --- a/src/pint/logging.py +++ b/src/pint/logging.py @@ -72,7 +72,7 @@ # https://loguru.readthedocs.io/en/stable/api/logger.html#color showwarning_ = warnings.showwarning -warning_onceregistry = {} +warning_onceregistry: dict[tuple[str, str], int] = {} # basic loguru level definitions from: # https://loguru.readthedocs.io/en/stable/api/logger.html diff --git a/src/pint/pulsar_mjd.py b/src/pint/pulsar_mjd.py index 3c0f14d8d..55f8f6b61 100644 --- a/src/pint/pulsar_mjd.py +++ b/src/pint/pulsar_mjd.py @@ -34,13 +34,6 @@ from astropy.time import Time from astropy.time.formats import TimeFormat -try: - maketrans = str.maketrans -except AttributeError: - # fallback for Python 2 - from string import maketrans - - # This check is implemented in pint.utils, but we want to avoid circular imports if np.finfo(np.longdouble).eps > 2e-19: import warnings @@ -303,7 +296,7 @@ def fortran_float(x): """ try: # First treat it as a string, wih d->e - return float(x.translate(maketrans("Dd", "ee"))) + return float(x.translate(str.maketrans("Dd", "ee"))) except AttributeError: # If that didn't work it may already be a numeric type return float(x) @@ -361,7 +354,7 @@ def str2longdouble(str_data): """ if not isinstance(str_data, (str, bytes)): raise TypeError("Need a string: {!r}".format(str_data)) - return np.longdouble(str_data.translate(maketrans("Dd", "ee"))) + return np.longdouble(str_data.translate(str.maketrans("Dd", "ee"))) # Simplified functions: These core functions, if they can be made to work @@ -453,7 +446,7 @@ def mjds_to_jds_pulsar(mjd1, mjd2): def _str_to_mjds(s): ss = s.lower().strip() if "e" in ss or "d" in ss: - ss = ss.translate(maketrans("d", "e")) + ss = ss.translate(str.maketrans("d", "e")) num, expon = ss.split("e") expon = int(expon) if expon < 0: diff --git a/src/pint/py.typed b/src/pint/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/tox.ini b/tox.ini index 05ed53ca3..67d010f84 100644 --- a/tox.ini +++ b/tox.ini @@ -13,6 +13,7 @@ envlist = report codestyle black + mypy py{38,39,310,311,312}-test{,-alldeps,-devdeps}{,-cov} skip_missing_interpreters = True @@ -136,3 +137,12 @@ deps = commands = black --check src tests examples +[testenv:mypy] +changedir = . +description = use mypy +deps = + mypy==1.8.0 + GitPython + types-setuptools + types-tqdm +commands = mypy \ No newline at end of file From a78d8c5fd8cc659e92a3aa5fb730ca45985bf07e Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sun, 25 Feb 2024 14:31:34 +0000 Subject: [PATCH 09/28] Set up CI action --- .github/workflows/ci_test.yml | 3 +++ .gitignore | 3 +++ 2 files changed, 6 insertions(+) diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml index f43615c9c..e296c0769 100644 --- a/.github/workflows/ci_test.yml +++ b/.github/workflows/ci_test.yml @@ -35,6 +35,9 @@ jobs: - os: ubuntu-latest python: '3.12' tox_env: 'notebooks' + - os: ubuntu-latest + python: '3.12' + tox_env: 'mypy' # - os: ubuntu-latest # python: '3.8' # tox_env: 'docs' diff --git a/.gitignore b/.gitignore index c95fdb4b1..549cd5bea 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,6 @@ tests/datafile/par_*.par tests/datafile/fake_toas.tim tests/datafile/*.converted.par tests/datafile/_test_pintempo.out + +# mypy +.mypy_cache \ No newline at end of file From 8c0f22fb4d935d0054bed54efee19e5626bf4637 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sun, 25 Feb 2024 15:02:05 +0000 Subject: [PATCH 10/28] Enforce python3.12 and fix syntaxwarnings --- src/pint/derived_quantities.py | 42 ++++++++++++++++------------------ src/pint/output/publish.py | 6 ++--- tox.ini | 1 + 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/pint/derived_quantities.py b/src/pint/derived_quantities.py index 104b73d8d..d8cb66390 100644 --- a/src/pint/derived_quantities.py +++ b/src/pint/derived_quantities.py @@ -34,7 +34,7 @@ p=[u.Hz, u.s], pd=[u.Hz / u.s, u.s / u.s], pdd=[u.Hz / u.s**2, u.s / u.s**2] ) def p_to_f(p, pd, pdd=None): - """Converts P, Pdot to F, Fdot (or vice versa) + r"""Converts P, Pdot to F, Fdot (or vice versa) Convert period, period derivative and period second derivative (if supplied) to the equivalent frequency counterparts. @@ -81,7 +81,7 @@ def p_to_f(p, pd, pdd=None): pdorfderr=[u.Hz / u.s, u.s / u.s], ) def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): - """Convert P, Pdot to F, Fdot with uncertainties (or vice versa). + r"""Convert P, Pdot to F, Fdot with uncertainties (or vice versa). Calculate the period or frequency errors and the Pdot or fdot errors from the opposite ones. @@ -129,7 +129,7 @@ def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): @u.quantity_input(fo=u.Hz) def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz): - """Compute pulsar characteristic age + r"""Compute pulsar characteristic age Return the age of a pulsar given the spin frequency and frequency derivative. By default, the characteristic age @@ -172,7 +172,7 @@ def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz): @u.quantity_input(I=u.g * u.cm**2) def pulsar_edot(f: u.Hz, fdot: u.Hz / u.s, I=1.0e45 * u.g * u.cm**2): - """Compute pulsar spindown energy loss rate + r"""Compute pulsar spindown energy loss rate Return the pulsar `Edot` (:math:`\dot E`, in erg/s) given the spin frequency `f` and frequency derivative `fdot`. The NS moment of inertia is assumed to be @@ -208,7 +208,7 @@ def pulsar_edot(f: u.Hz, fdot: u.Hz / u.s, I=1.0e45 * u.g * u.cm**2): @u.quantity_input def pulsar_B(f: u.Hz, fdot: u.Hz / u.s): - """Compute pulsar surface magnetic field + r"""Compute pulsar surface magnetic field Return the estimated pulsar surface magnetic field strength given the spin frequency and frequency derivative. @@ -243,7 +243,7 @@ def pulsar_B(f: u.Hz, fdot: u.Hz / u.s): @u.quantity_input def pulsar_B_lightcyl(f: u.Hz, fdot: u.Hz / u.s): - """Compute pulsar magnetic field at the light cylinder + r"""Compute pulsar magnetic field at the light cylinder Return the estimated pulsar magnetic field strength at the light cylinder given the spin frequency and @@ -285,7 +285,7 @@ def pulsar_B_lightcyl(f: u.Hz, fdot: u.Hz / u.s): @u.quantity_input def mass_funct(pb: u.d, x: u.cm): - """Compute binary mass function from period and semi-major axis + r"""Compute binary mass function from period and semi-major axis Can handle scalar or array inputs. @@ -326,7 +326,7 @@ def mass_funct(pb: u.d, x: u.cm): @u.quantity_input def mass_funct2(mp: u.Msun, mc: u.Msun, i: u.deg): - """Compute binary mass function from masses and inclination + r"""Compute binary mass function from masses and inclination Can handle scalar or array inputs. @@ -371,7 +371,7 @@ def mass_funct2(mp: u.Msun, mc: u.Msun, i: u.deg): @u.quantity_input def pulsar_mass(pb: u.d, x: u.cm, mc: u.Msun, i: u.deg): - """Compute pulsar mass from orbital parameters + r"""Compute pulsar mass from orbital parameters Return the pulsar mass (in solar mass units) for a binary. Can handle scalar or array inputs. @@ -438,7 +438,7 @@ def pulsar_mass(pb: u.d, x: u.cm, mc: u.Msun, i: u.deg): @u.quantity_input(inc=u.deg, mpsr=u.solMass) def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass): - """Commpute the companion mass from the orbital parameters + r"""Commpute the companion mass from the orbital parameters Compute companion mass for a binary system from orbital mechanics, not Shapiro delay. @@ -542,7 +542,7 @@ def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass): @u.quantity_input def pbdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): - """Post-Keplerian orbital decay pbdot, assuming general relativity. + r"""Post-Keplerian orbital decay pbdot, assuming general relativity. pbdot (:math:`\dot P_B`) is the change in the binary orbital period due to emission of gravitational waves. @@ -605,7 +605,7 @@ def pbdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): @u.quantity_input def gamma(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): - """Post-Keplerian time dilation and gravitational redshift gamma, assuming general relativity. + r"""Post-Keplerian time dilation and gravitational redshift gamma, assuming general relativity. gamma (:math:`\gamma`) is the amplitude of the modification in arrival times caused by the varying gravitational redshift of the companion and time dilation in an elliptical orbit. The time delay is @@ -661,7 +661,7 @@ def gamma(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): @u.quantity_input def omdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): - """Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. + r"""Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. omdot (:math:`\dot \omega`) is the relativistic advance of periastron. Can handle scalar or array inputs. @@ -716,7 +716,7 @@ def omdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled): @u.quantity_input def sini(mp: u.Msun, mc: u.Msun, pb: u.d, x: u.cm): - """Post-Keplerian sine of inclination, assuming general relativity. + r"""Post-Keplerian sine of inclination, assuming general relativity. Can handle scalar or array inputs. @@ -916,14 +916,12 @@ def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): """ return ( ( - ( - omdot - / ( - 3 - * (const.G / const.c**3) ** (2.0 / 3) - * (pb / (2 * np.pi)) ** (-5.0 / 3) - * (1 - e**2) ** (-1) - ) + omdot + / ( + 3 + * (const.G / const.c**3) ** (2.0 / 3) + * (pb / (2 * np.pi)) ** (-5.0 / 3) + * (1 - e**2) ** (-1) ) ) ** (3.0 / 2) diff --git a/src/pint/output/publish.py b/src/pint/output/publish.py index 2cb82b0ac..e6380eb1f 100644 --- a/src/pint/output/publish.py +++ b/src/pint/output/publish.py @@ -259,7 +259,7 @@ def publish( ) tex.write("\\hline\n") - tex.write("\multicolumn{2}{c}{Measured Quantities} \\\\ \n") + tex.write("\\multicolumn{2}{c}{Measured Quantities} \\\\ \n") tex.write("\\hline\n") for fp in model.free_params: param = getattr(model, fp) @@ -273,7 +273,7 @@ def publish( tex.write("\\hline\n") if include_set_params: - tex.write("\multicolumn{2}{c}{Set Quantities} \\\\ \n") + tex.write("\\multicolumn{2}{c}{Set Quantities} \\\\ \n") tex.write("\\hline\n") for p in model.params: param = getattr(model, p) @@ -303,7 +303,7 @@ def publish( and getattr(model, p).quantity is not None ] if len(derived_params) > 0: - tex.write("\multicolumn{2}{c}{Derived Quantities} \\\\ \n") + tex.write("\\multicolumn{2}{c}{Derived Quantities} \\\\ \n") tex.write("\\hline\n") for param in derived_params: tex.write(publish_param(param)) diff --git a/tox.ini b/tox.ini index 67d010f84..5b18d6ba0 100644 --- a/tox.ini +++ b/tox.ini @@ -140,6 +140,7 @@ commands = black --check src tests examples [testenv:mypy] changedir = . description = use mypy +basepython = python3.12 deps = mypy==1.8.0 GitPython From 7ff5a752cfb5dd9f6a3bf81cd5df7ca948aadc25 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sun, 25 Feb 2024 15:14:44 +0000 Subject: [PATCH 11/28] Finish updating string syntax --- src/pint/derived_quantities.py | 40 ++++++++----------- src/pint/models/binary_dd.py | 2 +- src/pint/models/binary_ddk.py | 6 +-- src/pint/models/pulsar_binary.py | 2 +- .../stand_alone_psr_binaries/DDK_model.py | 24 +++++------ .../stand_alone_psr_binaries/DDS_model.py | 2 +- .../stand_alone_psr_binaries/ELL1H_model.py | 4 +- src/pint/utils.py | 19 ++++----- 8 files changed, 46 insertions(+), 53 deletions(-) diff --git a/src/pint/derived_quantities.py b/src/pint/derived_quantities.py index d8cb66390..42c4dda04 100644 --- a/src/pint/derived_quantities.py +++ b/src/pint/derived_quantities.py @@ -120,8 +120,7 @@ def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): return [1.0 / porf, porferr / porf**2.0] forperr = porferr / porf**2.0 fdorpderr = np.sqrt( - (4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0 - + pdorfderr**2.0 / porf**4.0 + (4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0 + pdorfderr**2.0 / porf**4.0 ) [forp, fdorpd] = p_to_f(porf, pdorfd) return [forp, forperr, fdorpd, fdorpderr] @@ -165,7 +164,7 @@ def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz): .. math:: - \\tau = \\frac{f}{(n-1)\dot f}\\left(1-\\left(\\frac{f}{f_0}\\right)^{n-1}\\right) + \tau = \frac{f}{(n-1)\dot f}\left(1-\left(\frac{f}{f_0}\right)^{n-1}\right) """ return (-f / ((n - 1.0) * fdot) * (1.0 - (f / fo) ** (n - 1.0))).to(u.yr) @@ -515,17 +514,12 @@ def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass): # delta1 is always <0 # delta1 = 2 * b ** 3 - 9 * a * b * c + 27 * a ** 2 * d delta1 = ( - -2 * massfunct**3 - - 18 * a * mp * massfunct**2 - - 27 * a**2 * massfunct * mp**2 + -2 * massfunct**3 - 18 * a * mp * massfunct**2 - 27 * a**2 * massfunct * mp**2 ) # Q**2 is always > 0, so this is never a problem # in terms of complex numbers # Q = np.sqrt(delta1**2 - 4*delta0**3) - Q = np.sqrt( - 108 * a**3 * mp**3 * massfunct**3 - + 729 * a**4 * mp**4 * massfunct**2 - ) + Q = np.sqrt(108 * a**3 * mp**3 * massfunct**3 + 729 * a**4 * mp**4 * massfunct**2) # this could be + or - Q # pick the - branch since delta1 is <0 so that delta1 - Q is never near 0 Ccubed = 0.5 * (delta1 + Q) @@ -770,7 +764,7 @@ def sini(mp: u.Msun, mc: u.Msun, pb: u.d, x: u.cm): @u.quantity_input def dr(mp: u.Msun, mc: u.Msun, pb: u.d): - """Post-Keplerian Roemer delay term + r"""Post-Keplerian Roemer delay term dr (:math:`\delta_r`) is part of the relativistic deformation of the orbit @@ -820,9 +814,9 @@ def dr(mp: u.Msun, mc: u.Msun, pb: u.d): @u.quantity_input def dth(mp: u.Msun, mc: u.Msun, pb: u.d): - """Post-Keplerian Roemer delay term + r"""Post-Keplerian Roemer delay term - dth (:math:`\delta_{\\theta}`) is part of the relativistic deformation of the orbit + dth (:math:`\delta_{\theta}`) is part of the relativistic deformation of the orbit Parameters ---------- @@ -850,8 +844,8 @@ def dth(mp: u.Msun, mc: u.Msun, pb: u.d): .. math:: - \delta_{\\theta} = T_{\odot}^{2/3} \\left(\\frac{P_b}{2\pi}\\right)^{2/3} - \\frac{3.5 m_p^2+6 m_p m_c +2m_c^2}{(m_p+m_c)^{4/3}} + \delta_{\theta} = T_{\odot}^{2/3} \left(\frac{P_b}{2\pi}\right)^{2/3} + \frac{3.5 m_p^2+6 m_p m_c +2m_c^2}{(m_p+m_c)^{4/3}} with :math:`T_\odot = GM_\odot c^{-3}`. @@ -870,7 +864,7 @@ def dth(mp: u.Msun, mc: u.Msun, pb: u.d): @u.quantity_input def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): - """Determine total mass from Post-Keplerian longitude of periastron precession rate omdot, + r"""Determine total mass from Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. omdot (:math:`\dot \omega`) is the relativistic advance of periastron. It relates to the total @@ -904,8 +898,8 @@ def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): .. math:: - \dot \omega = 3T_{\odot}^{2/3} \\left(\\frac{P_b}{2\pi}\\right)^{-5/3} - \\frac{1}{1-e^2}(m_p+m_c)^{2/3} + \dot \omega = 3T_{\odot}^{2/3} \left(\frac{P_b}{2\pi}\right)^{-5/3} + \frac{1}{1-e^2}(m_p+m_c)^{2/3} to calculate :math:`m_{\\rm tot} = m_p + m_c`, with :math:`T_\odot = GM_\odot c^{-3}`. @@ -930,7 +924,7 @@ def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): @u.quantity_input(pb=u.d, mp=u.Msun, mc=u.Msun, i=u.deg) def a1sini(mp, mc, pb, i=90 * u.deg): - """Pulsar's semi-major axis. + r"""Pulsar's semi-major axis. The full semi-major axis is given by Kepler's third law. This is the projection (:math:`\sin i`) of just the pulsar's orbit (:math:`m_c/(m_p+m_c)` @@ -966,8 +960,8 @@ def a1sini(mp, mc, pb, i=90 * u.deg): .. math:: - \\frac{a_p \sin i}{c} = \\frac{m_c \sin i}{(m_p+m_c)^{2/3}} - G^{1/3}\\left(\\frac{P_b}{2\pi}\\right)^{2/3} + \frac{a_p \sin i}{c} = \frac{m_c \sin i}{(m_p+m_c)^{2/3}} + G^{1/3}\left(\frac{P_b}{2\pi}\right)^{2/3} More details in :ref:`Timing Models`. Also see [8]_ @@ -982,7 +976,7 @@ def a1sini(mp, mc, pb, i=90 * u.deg): @u.quantity_input def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc): - """Compute magnitude of Shklovskii correction factor. + r"""Compute magnitude of Shklovskii correction factor. Computes the Shklovskii correction factor, as defined in Eq 8.12 of Lorimer & Kramer (2005) [10]_ This is the factor by which :math:`\dot P /P` is increased due to the transverse velocity. @@ -991,7 +985,7 @@ def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc): .. math:: - \dot P_{\\rm intrinsic} = \dot P_{\\rm observed} - a_s P + \dot P_{\rm intrinsic} = \dot P_{\rm observed} - a_s P Parameters ---------- diff --git a/src/pint/models/binary_dd.py b/src/pint/models/binary_dd.py index b07df462f..adec86b4a 100644 --- a/src/pint/models/binary_dd.py +++ b/src/pint/models/binary_dd.py @@ -127,7 +127,7 @@ def validate(self): class BinaryDDS(BinaryDD): - """Damour and Deruelle model with alternate Shapiro delay parameterization. + r"""Damour and Deruelle model with alternate Shapiro delay parameterization. This extends the :class:`pint.models.binary_dd.BinaryDD` model with :math:`SHAPMAX = -\log(1-s)` instead of just :math:`s=\sin i`, which behaves better diff --git a/src/pint/models/binary_ddk.py b/src/pint/models/binary_ddk.py index 6b40f5ebc..1b5aa9e52 100644 --- a/src/pint/models/binary_ddk.py +++ b/src/pint/models/binary_ddk.py @@ -41,7 +41,7 @@ def _convert_kom(kom): class BinaryDDK(BinaryDD): - """Damour and Deruelle model with kinematics. + r"""Damour and Deruelle model with kinematics. This extends the :class:`pint.models.binary_dd.BinaryDD` model with "Shklovskii" and "Kopeikin" terms that account for the finite distance @@ -220,14 +220,14 @@ def validate(self): warnings.warn("Using A1DOT with a DDK model is not advised.") def alternative_solutions(self): - """Alternative Kopeikin solutions (potential local minima) + r"""Alternative Kopeikin solutions (potential local minima) There are 4 potential local minima for a DDK model where a1dot is the same These are given by where Eqn. 8 in Kopeikin (1996) is equal to the best-fit value. We first define the symmetry point where a1dot is zero (in equatorial coordinates): - :math:`KOM_0 = \\tan^{-1} (\mu_{\delta} / \mu_{\\alpha})` + :math:`KOM_0 = \tan^{-1} (\mu_{\delta} / \mu_{\alpha})` The solutions are then: diff --git a/src/pint/models/pulsar_binary.py b/src/pint/models/pulsar_binary.py index b27dc923c..dd9231797 100644 --- a/src/pint/models/pulsar_binary.py +++ b/src/pint/models/pulsar_binary.py @@ -38,7 +38,7 @@ class PulsarBinary(DelayComponent): - """Base class for binary models in PINT. + r"""Base class for binary models in PINT. This class provides a wrapper for internal classes that do the actual calculations. The calculations are done by the classes located in diff --git a/src/pint/models/stand_alone_psr_binaries/DDK_model.py b/src/pint/models/stand_alone_psr_binaries/DDK_model.py index 709033aae..be2e5f906 100644 --- a/src/pint/models/stand_alone_psr_binaries/DDK_model.py +++ b/src/pint/models/stand_alone_psr_binaries/DDK_model.py @@ -8,7 +8,7 @@ class DDKmodel(DDmodel): - """DDK model, a Kopeikin method corrected DD model. + r"""DDK model, a Kopeikin method corrected DD model. The main difference is that DDK model considers the effects on the pulsar binary parameters from the annual parallax of earth and the proper motion of the pulsar. @@ -155,7 +155,7 @@ def SINI(self, val): # Update binary parameters due to the pulser proper motion def delta_kin_proper_motion(self): - """The time dependent inclination angle + r"""The time dependent inclination angle (Kopeikin 1996 Eq 10): .. math:: @@ -231,7 +231,7 @@ def d_kin_d_par(self, par): return func() def delta_a1_proper_motion(self): - """The correction on a1 (projected semi-major axis) + r"""The correction on a1 (projected semi-major axis) due to the pulsar proper motion (Kopeikin 1996 Eq 8): @@ -289,7 +289,7 @@ def d_delta_a1_proper_motion_d_T0(self): return d_delta_a1_proper_motion_d_T0.to(a1.unit / self.T0.unit) def delta_omega_proper_motion(self): - """The correction on omega (Longitude of periastron) + r"""The correction on omega (Longitude of periastron) due to the pulsar proper motion (Kopeikin 1996 Eq 9): @@ -353,7 +353,7 @@ def d_delta_omega_proper_motion_d_T0(self): # Reference KOPEIKIN. 1995 Eq 18 -> Eq 19. def delta_I0(self): - """ + r""" :math:`\Delta_{I0}` Reference: (Kopeikin 1995 Eq 15) @@ -361,7 +361,7 @@ def delta_I0(self): return -self.obs_pos[:, 0] * self.sin_long + self.obs_pos[:, 1] * self.cos_long def delta_J0(self): - """ + r""" :math:`\Delta_{J0}` Reference: (Kopeikin 1995 Eq 16) @@ -373,19 +373,19 @@ def delta_J0(self): ) def delta_sini_parallax(self): - """Reference (Kopeikin 1995 Eq 18). Computes: + r"""Reference (Kopeikin 1995 Eq 18). Computes: .. math:: - x_{obs} = \\frac{a_p \sin(i)_{obs}}{c} + x_{obs} = \frac{a_p \sin(i)_{obs}}{c} Since :math:`a_p` and :math:`c` will not be changed by parallax: .. math:: - x_{obs} = \\frac{a_p}{c}(\sin(i)_{\\rm intrisic} + \delta_{\sin(i)}) + x_{obs} = \frac{a_p}{c}(\sin(i)_{\rm intrisic} + \delta_{\sin(i)}) - \delta_{\sin(i)} = \sin(i)_{\\rm intrisic} \\frac{\cot(i)_{\\rm intrisic}}{d} (\Delta_{I0} \sin KOM - \Delta_{J0} \cos KOM) + \delta_{\sin(i)} = \sin(i)_{\rm intrisic} \frac{\cot(i)_{\rm intrisic}}{d} (\Delta_{I0} \sin KOM - \Delta_{J0} \cos KOM) """ PX_kpc = self.PX.to(u.kpc, equivalencies=u.parallax()) @@ -518,9 +518,7 @@ def d_delta_omega_parallax_d_T0(self): PX_kpc = self.PX.to(u.kpc, equivalencies=u.parallax()) kom_projection = self.delta_I0() * self.cos_KOM + self.delta_J0() * self.sin_KOM d_kin_d_T0 = self.d_kin_d_par("T0") - d_delta_omega_d_T0 = ( - cos_kin / sin_kin**2 / PX_kpc * d_kin_d_T0 * kom_projection - ) + d_delta_omega_d_T0 = cos_kin / sin_kin**2 / PX_kpc * d_kin_d_T0 * kom_projection return d_delta_omega_d_T0.to( self.OM.unit / self.T0.unit, equivalencies=u.dimensionless_angles() ) diff --git a/src/pint/models/stand_alone_psr_binaries/DDS_model.py b/src/pint/models/stand_alone_psr_binaries/DDS_model.py index 1fc81fc56..db572f7f6 100644 --- a/src/pint/models/stand_alone_psr_binaries/DDS_model.py +++ b/src/pint/models/stand_alone_psr_binaries/DDS_model.py @@ -10,7 +10,7 @@ class DDSmodel(DDmodel): - """Damour and Deruelle model with alternate Shapiro delay parameterization. + r"""Damour and Deruelle model with alternate Shapiro delay parameterization. This extends the :class:`pint.models.binary_dd.BinaryDD` model with :math:`SHAPMAX = -\log(1-s)` instead of just :math:`s=\sin i`, which behaves better diff --git a/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py b/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py index 7a8de20e1..35930dbe8 100644 --- a/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py +++ b/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py @@ -9,7 +9,7 @@ class ELL1Hmodel(ELL1BaseModel): - """ELL1H pulsar binary model using H3, H4 or STIGMA as shapiro delay parameters. + r"""ELL1H pulsar binary model using H3, H4 or STIGMA as shapiro delay parameters. Note ---- @@ -21,7 +21,7 @@ class ELL1Hmodel(ELL1BaseModel): .. math:: - \\Delta_S = -2r \\left( \\frac{a_0}{2} + \\Sum_k (a_k \\cos k\\phi + b_k \\sin k \phi) \\right) + \Delta_S = -2r \left( \frac{a_0}{2} + \Sum_k (a_k \cos k\phi + b_k \sin k \phi) \right) The first two harmonics are generlly absorbed by the ELL1 Roemer delay. Thus, :class:`~pint.models.binary_ell1.BinaryELL1H` uses the series from the third diff --git a/src/pint/utils.py b/src/pint/utils.py index 1c0bcf21c..1331ef53b 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -1644,9 +1644,10 @@ def get_wavex_amps(model, index=None, quantity=False): model.components["WaveX"].get_prefix_mapping_component("WXSIN_").keys() ) if len(indices) == 1: - values = getattr( - model.components["WaveX"], f"WXSIN_{int(indices):04d}" - ), getattr(model.components["WaveX"], f"WXCOS_{int(indices):04d}") + values = ( + getattr(model.components["WaveX"], f"WXSIN_{int(indices):04d}"), + getattr(model.components["WaveX"], f"WXCOS_{int(indices):04d}"), + ) else: values = [ ( @@ -1657,8 +1658,9 @@ def get_wavex_amps(model, index=None, quantity=False): ] elif isinstance(index, (int, float, np.int64)): idx_rf = f"{int(index):04d}" - values = getattr(model.components["WaveX"], f"WXSIN_{idx_rf}"), getattr( - model.components["WaveX"], f"WXCOS_{idx_rf}" + values = ( + getattr(model.components["WaveX"], f"WXSIN_{idx_rf}"), + getattr(model.components["WaveX"], f"WXCOS_{idx_rf}"), ) elif isinstance(index, (list, set, np.ndarray)): idx_rf = [f"{int(idx):04d}" for idx in index] @@ -1784,12 +1786,12 @@ def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): def ELL1_check( A1: u.cm, E: u.dimensionless_unscaled, TRES: u.us, NTOA: int, outstring=True ): - """Check for validity of assumptions in ELL1 binary model + r"""Check for validity of assumptions in ELL1 binary model Checks whether the assumptions that allow ELL1 to be safely used are satisfied. To work properly, we should have: - :math:`asini/c e^4 \ll {\\rm timing precision} / \sqrt N_{\\rm TOA}` - or :math:`A1 E^4 \ll TRES / \sqrt N_{\\rm TOA}` + :math:`asini/c e^4 \ll {\rm timing precision} / \sqrt N_{\rm TOA}` + or :math:`A1 E^4 \ll TRES / \sqrt N_{\rm TOA}` since the ELL1 model now includes terms up to O(E^3) @@ -1810,7 +1812,6 @@ def ELL1_check( bool or str Returns True if ELL1 is safe to use, otherwise False. If outstring is True then returns a string summary instead. - """ lhs = A1 / const.c * E**4.0 rhs = TRES / np.sqrt(NTOA) From 63bf8665fe808615410135af44eb422eca3e8587 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sun, 25 Feb 2024 15:26:58 +0000 Subject: [PATCH 12/28] Finish updating string syntax --- src/pint/eventstats.py | 2 +- src/pint/gridutils.py | 2 +- src/pint/logging.py | 3 ++- src/pint/polycos.py | 28 ++++++++++++++-------------- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/pint/eventstats.py b/src/pint/eventstats.py index b24c39aae..2fb68dbf1 100644 --- a/src/pint/eventstats.py +++ b/src/pint/eventstats.py @@ -47,7 +47,7 @@ def from_array(x): def sig2sigma(sig, two_tailed=True, logprob=False): - """Convert tail probability to "sigma" units. + r"""Convert tail probability to "sigma" units. Find the value of the argument for the normal distribution beyond which the integrated tail probability is sig. Note that the default is to interpret diff --git a/src/pint/gridutils.py b/src/pint/gridutils.py index 4a41d1c1d..4e2ec06d7 100644 --- a/src/pint/gridutils.py +++ b/src/pint/gridutils.py @@ -164,7 +164,7 @@ def grid_chisq( printprogress=True, **fitargs, ): - """Compute chisq over a grid of parameters + r"""Compute chisq over a grid of parameters Parameters ---------- diff --git a/src/pint/logging.py b/src/pint/logging.py index 5187091b3..8cd4a2cad 100644 --- a/src/pint/logging.py +++ b/src/pint/logging.py @@ -124,12 +124,13 @@ def showwarning(message, category, filename, lineno, file=None, line=None): class LogFilter: """Custom logging filter for ``loguru``. + Define some messages that are never seen (e.g., Deprecation Warnings). Others that will only be seen once. Filtering of those is done on the basis of regular expressions. """ def __init__(self, onlyonce=None, never=None, onlyonce_level="INFO"): - """ + r""" Define regexs for messages that will only be seen once. Use ``\S+`` for a variable that might change. If a message comes through with a new value for that variable, it will be seen. diff --git a/src/pint/polycos.py b/src/pint/polycos.py index 686576b9b..0542d3304 100644 --- a/src/pint/polycos.py +++ b/src/pint/polycos.py @@ -1,4 +1,4 @@ -"""Polynomial coefficients for phase prediction +r"""Polynomial coefficients for phase prediction Polycos designed to predict the pulsar's phase and pulse-period over a given interval using polynomial expansions. @@ -7,11 +7,11 @@ .. math:: - \\Delta T = 1440(T-T_{\\rm mid}) + \Delta T = 1440(T-T_{\rm mid}) - \\phi = \\phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \\Delta T + COEFF[3] \\Delta T^2 + \\ldots + \phi = \phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \Delta T + COEFF[3] \Delta T^2 + \ldots - f({\\rm Hz}) = f_0 + \\frac{1}{60}\\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \\ldots \\right) + f({\rm Hz}) = f_0 + \frac{1}{60}\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \ldots \right) Examples -------- @@ -228,7 +228,7 @@ def evalfreqderiv(self, t): # Read polycos file data to table def tempo_polyco_table_reader(filename): - """Read tempo style polyco file to an astropy table. + r"""Read tempo style polyco file to an astropy table. Tempo style: The polynomial ephemerides are written to file 'polyco.dat'. Entries are listed sequentially within the file. The file format is:: @@ -262,11 +262,11 @@ def tempo_polyco_table_reader(filename): .. math:: - \\Delta T = 1440(T-T_{\\rm mid}) + \Delta T = 1440(T-T_{\rm mid}) - \\phi = \\phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \\Delta T + COEFF[3] \\Delta T^2 + \\ldots + \phi = \phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \Delta T + COEFF[3] \Delta T^2 + \ldots - f({\\rm Hz}) = f_0 + \\frac{1}{60}\\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \\ldots \\right) + f({\rm Hz}) = f_0 + \frac{1}{60}\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \ldots \right) Parameters ---------- @@ -356,7 +356,7 @@ def tempo_polyco_table_reader(filename): def tempo_polyco_table_writer(polycoTable, filename="polyco.dat"): - """Write tempo style polyco file from an astropy table. + r"""Write tempo style polyco file from an astropy table. Tempo style polyco file: The polynomial ephemerides are written to file 'polyco.dat'. Entries @@ -389,11 +389,11 @@ def tempo_polyco_table_writer(polycoTable, filename="polyco.dat"): .. math:: - \\Delta T = 1440(T-T_{\\rm mid}) + \Delta T = 1440(T-T_{\rm mid}) - \\phi = \\phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \\Delta T + COEFF[3] \\Delta T^2 + \\ldots + \phi = \phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \Delta T + COEFF[3] \Delta T^2 + \ldots - f({\\rm Hz}) = f_0 + \\frac{1}{60}\\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \\ldots \\right) + f({\rm Hz}) = f_0 + \frac{1}{60}\left( COEFF[2] + 2 COEFF[3] \Delta T + 3 COEFF[4] \Delta T^2 + \ldots \right) Parameters ---------- @@ -918,7 +918,7 @@ def eval_phase(self, t): return self.eval_abs_phase(t).frac def eval_abs_phase(self, t): - """ + r""" Polyco evaluate absolute phase for a time array. Parameters @@ -937,7 +937,7 @@ def eval_abs_phase(self, t): .. math:: - \\phi = \\phi_0 + 60 \\Delta T f_0 + COEFF[1] + COEFF[2] \Delta T + COEFF[3] \Delta T^2 + \\ldots + \phi = \phi_0 + 60 \Delta T f_0 + COEFF[1] + COEFF[2] \Delta T + COEFF[3] \Delta T^2 + \ldots Calculation done using :meth:`pint.polycos.PolycoEntry.evalabsphase` """ From 8df8971df7c5e63b8247a6fe17fb773581fe4ee5 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sun, 25 Feb 2024 15:28:51 +0000 Subject: [PATCH 13/28] Fighting with black --- src/pint/derived_quantities.py | 12 +++++++++--- .../models/stand_alone_psr_binaries/DDK_model.py | 4 +++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/pint/derived_quantities.py b/src/pint/derived_quantities.py index 42c4dda04..0d3646078 100644 --- a/src/pint/derived_quantities.py +++ b/src/pint/derived_quantities.py @@ -120,7 +120,8 @@ def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): return [1.0 / porf, porferr / porf**2.0] forperr = porferr / porf**2.0 fdorpderr = np.sqrt( - (4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0 + pdorfderr**2.0 / porf**4.0 + (4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0 + + pdorfderr**2.0 / porf**4.0 ) [forp, fdorpd] = p_to_f(porf, pdorfd) return [forp, forperr, fdorpd, fdorpderr] @@ -514,12 +515,17 @@ def companion_mass(pb: u.d, x: u.cm, i=60.0 * u.deg, mp=1.4 * u.solMass): # delta1 is always <0 # delta1 = 2 * b ** 3 - 9 * a * b * c + 27 * a ** 2 * d delta1 = ( - -2 * massfunct**3 - 18 * a * mp * massfunct**2 - 27 * a**2 * massfunct * mp**2 + -2 * massfunct**3 + - 18 * a * mp * massfunct**2 + - 27 * a**2 * massfunct * mp**2 ) # Q**2 is always > 0, so this is never a problem # in terms of complex numbers # Q = np.sqrt(delta1**2 - 4*delta0**3) - Q = np.sqrt(108 * a**3 * mp**3 * massfunct**3 + 729 * a**4 * mp**4 * massfunct**2) + Q = np.sqrt( + 108 * a**3 * mp**3 * massfunct**3 + + 729 * a**4 * mp**4 * massfunct**2 + ) # this could be + or - Q # pick the - branch since delta1 is <0 so that delta1 - Q is never near 0 Ccubed = 0.5 * (delta1 + Q) diff --git a/src/pint/models/stand_alone_psr_binaries/DDK_model.py b/src/pint/models/stand_alone_psr_binaries/DDK_model.py index be2e5f906..a586400b0 100644 --- a/src/pint/models/stand_alone_psr_binaries/DDK_model.py +++ b/src/pint/models/stand_alone_psr_binaries/DDK_model.py @@ -518,7 +518,9 @@ def d_delta_omega_parallax_d_T0(self): PX_kpc = self.PX.to(u.kpc, equivalencies=u.parallax()) kom_projection = self.delta_I0() * self.cos_KOM + self.delta_J0() * self.sin_KOM d_kin_d_T0 = self.d_kin_d_par("T0") - d_delta_omega_d_T0 = cos_kin / sin_kin**2 / PX_kpc * d_kin_d_T0 * kom_projection + d_delta_omega_d_T0 = ( + cos_kin / sin_kin**2 / PX_kpc * d_kin_d_T0 * kom_projection + ) return d_delta_omega_d_T0.to( self.OM.unit / self.T0.unit, equivalencies=u.dimensionless_angles() ) From c920ce7c6ff2012e69af875a683aa3fee6ade493 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Tue, 27 Feb 2024 19:00:13 +0000 Subject: [PATCH 14/28] Try to fix mypy --- .github/workflows/ci_test.yml | 6 ++-- mypy.ini | 64 +++++++++++++++++++++++++++++++++++ setup.cfg | 64 ----------------------------------- src/pint/logging.py | 5 +-- tox.ini | 5 +-- 5 files changed, 73 insertions(+), 71 deletions(-) create mode 100644 mypy.ini diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml index e296c0769..74c8ba9f5 100644 --- a/.github/workflows/ci_test.yml +++ b/.github/workflows/ci_test.yml @@ -31,13 +31,13 @@ jobs: tox_env: 'black' - os: ubuntu-latest python: '3.12' - tox_env: 'py312-test-cov' + tox_env: 'mypy' - os: ubuntu-latest python: '3.12' - tox_env: 'notebooks' + tox_env: 'py312-test-cov' - os: ubuntu-latest python: '3.12' - tox_env: 'mypy' + tox_env: 'notebooks' # - os: ubuntu-latest # python: '3.8' # tox_env: 'docs' diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..f90950d61 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,64 @@ + +[mypy] +warn_unused_configs = True +files = src/pint + +[mypy-pint.templates.*] +ignore_errors = True + +[mypy-pint.derived_quantities] +ignore_errors = True + +[mypy-pint.observatory.*] +ignore_errors = True + +[mypy-pint.models.stand_alone_psr_binaries.*] +ignore_errors = True + +[mypy-pint.models.timing_model] +ignore_errors = True + +[mypy-pint.models.pulsar_binary] +ignore_errors = True + +[mypy-pint.fitter] +ignore_errors = True + +[mypy-pint.polycos] +ignore_errors = True + +[mypy-pint.output.publish] +ignore_errors = True + +[mypy-pint.gridutils] +ignore_errors = True + +[mypy-pint.scripts.*] +ignore_errors = True + +[mypy-pint.pintk.plk] +ignore_errors = True + +[mypy-astropy.*] +ignore_missing_imports = True + +[mypy-erfa] +ignore_missing_imports = True + +[mypy-emcee] +ignore_missing_imports = True + +[mypy-jplephem] +ignore_missing_imports = True + +[mypy-numdifftools] +ignore_missing_imports = True + +[mypy-pylab] +ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True + +[mypy-uncertainties] +ignore_missing_imports = True \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 9e841acc0..ad2570a5f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -129,67 +129,3 @@ skip_glob = src/pint/extern/* include_trailing_comma = True combine_as_imports = True - -[mypy] -warn_unused_configs = True -files = src/pint - -[mypy-pint.templates.*] -ignore_errors = True - -[mypy-pint.derived_quantities] -ignore_errors = True - -[mypy-pint.observatory.*] -ignore_errors = True - -[mypy-pint.models.stand_alone_psr_binaries.*] -ignore_errors = True - -[mypy-pint.models.timing_model] -ignore_errors = True - -[mypy-pint.models.pulsar_binary] -ignore_errors = True - -[mypy-pint.fitter] -ignore_errors = True - -[mypy-pint.polycos] -ignore_errors = True - -[mypy-pint.output.publish] -ignore_errors = True - -[mypy-pint.gridutils] -ignore_errors = True - -[mypy-pint.scripts.*] -ignore_errors = True - -[mypy-pint.pintk.plk] -ignore_errors = True - -[mypy-astropy.*] -ignore_missing_imports = True - -[mypy-erfa] -ignore_missing_imports = True - -[mypy-emcee] -ignore_missing_imports = True - -[mypy-jplephem] -ignore_missing_imports = True - -[mypy-numdifftools] -ignore_missing_imports = True - -[mypy-pylab] -ignore_missing_imports = True - -[mypy-scipy.*] -ignore_missing_imports = True - -[mypy-uncertainties] -ignore_missing_imports = True \ No newline at end of file diff --git a/src/pint/logging.py b/src/pint/logging.py index 8cd4a2cad..92827e192 100644 --- a/src/pint/logging.py +++ b/src/pint/logging.py @@ -53,9 +53,10 @@ import re import sys import warnings -from loguru import logger as log +from typing import Dict from erfa import ErfaWarning +from loguru import logger as log __all__ = ["LogFilter", "setup", "format", "levels", "get_level"] @@ -72,7 +73,7 @@ # https://loguru.readthedocs.io/en/stable/api/logger.html#color showwarning_ = warnings.showwarning -warning_onceregistry: dict[tuple[str, str], int] = {} +warning_onceregistry: Dict[tuple[str, str], int] = {} # basic loguru level definitions from: # https://loguru.readthedocs.io/en/stable/api/logger.html diff --git a/tox.ini b/tox.ini index 5b18d6ba0..e9126a3fa 100644 --- a/tox.ini +++ b/tox.ini @@ -132,9 +132,10 @@ commands = sphinx-build -d "{toxworkdir}/docs_doctree" . "{toxworkdir}/docs_out" skip_install = true changedir = . description = use black +basepython = python3.12 deps = black~=23.0 -commands = black --check src tests examples +commands = black src tests examples {posargs:--check} [testenv:mypy] @@ -146,4 +147,4 @@ deps = GitPython types-setuptools types-tqdm -commands = mypy \ No newline at end of file +commands = mypy --no-incremental {posargs} \ No newline at end of file From d993f41f6ed797ba6b3ed01123ca46bef50af5a3 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Tue, 27 Feb 2024 20:06:07 +0000 Subject: [PATCH 15/28] Fix typing in simulation.py --- .pre-commit-config.yaml | 2 +- src/pint/fitter.py | 49 ++++++++++++++++++++++++++++++++++++----- src/pint/simulation.py | 12 +++++----- 3 files changed, 50 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c2c4b8cc..de5ce1e24 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,6 +6,6 @@ repos: - id: check-merge-conflict - id: check-symlinks - repo: https://github.com/psf/black - rev: 24.2.0 + rev: 23.12.1 hooks: - id: black diff --git a/src/pint/fitter.py b/src/pint/fitter.py index c34ed3332..16430d438 100644 --- a/src/pint/fitter.py +++ b/src/pint/fitter.py @@ -60,6 +60,7 @@ import contextlib import copy +from typing import List, Optional from warnings import warn import astropy.units as u @@ -70,8 +71,9 @@ from numdifftools import Hessian import pint -import pint.utils import pint.derived_quantities +import pint.utils +import pint.models.timing_model from pint.models.parameter import ( AngleParameter, boolParameter, @@ -90,7 +92,6 @@ from pint.toa import TOAs from pint.utils import FTest, normalize_designmatrix - __all__ = [ "Fitter", "WLSFitter", @@ -221,7 +222,41 @@ class Fitter: ``GLSFitter`` is used to compute ``chi2`` for appropriate Residuals objects. """ - def __init__(self, toas, model, track_mode=None, residuals=None): + toas: TOAs + """TOAs to fit.""" + model_init: pint.models.timing_model.TimingModel + """Initial timing model the Fitter was created with.""" + track_mode: Optional[str] + """How to handle phase wrapping. + + This is used when creating :class:`pint.residuals.Residuals` + objects, and its meaning is defined there. + """ + resids_init: Residuals + """Initial residuals with respect to the timing model.""" + model: pint.models.timing_model.TimingModel + """Current timing model in use by the Fitter.""" + fitresult: List + method: Optional[str] + is_wideband: bool + converged: bool + parameter_covariance_matrix: CovarianceMatrix + """The covariance matrix of the model parameters after fitting. + + This attribute may not exist if the fitter has not been run + (some subclasses of Fitter don't compute this matrix except + as part of the fit, and don't create the attribute). + """ + fac: np.ndarray + """Scaling factors applied to the columns(?) of the design matrix.""" + + def __init__( + self, + toas: TOAs, + model: pint.models.TimingModel, + track_mode: Optional[str] = None, + residuals: Optional[Residuals] = None, + ): if not set(model.free_params).issubset(model.fittable_params): free_unfittable_params = set(model.free_params).difference( model.fittable_params @@ -490,9 +525,11 @@ def get_derived_params(self, returndict=False): """ return self.model.get_derived_params( - rms=self.resids.toa.rms_weighted() - if self.is_wideband - else self.resids.rms_weighted(), + rms=( + self.resids.toa.rms_weighted() + if self.is_wideband + else self.resids.rms_weighted() + ), ntoas=self.toas.ntoas, returndict=returndict, ) diff --git a/src/pint/simulation.py b/src/pint/simulation.py index 958e6200d..96ffa97f2 100644 --- a/src/pint/simulation.py +++ b/src/pint/simulation.py @@ -1,19 +1,19 @@ """Functions related to simulating TOAs and models """ +import pathlib from collections import OrderedDict from copy import deepcopy -from typing import Optional, List, Union -import pathlib +from typing import List, Optional, Tuple, Union import astropy.units as u import numpy as np -from loguru import logger as log from astropy import time +from loguru import logger as log +import pint.fitter import pint.residuals import pint.toa -import pint.fitter from pint.observatory import bipm_default, get_observatory __all__ = [ @@ -566,7 +566,7 @@ def calculate_random_models( keep_models: bool = True, return_time: bool = False, params: str = "all", -) -> (np.ndarray, Optional[list]): +) -> Tuple[np.ndarray, Optional[list]]: """ Calculates random models based on the covariance matrix of the `fitter` object. @@ -695,7 +695,7 @@ def _get_freqs_and_times( ntoas: int, freqs: u.Quantity, multi_freqs_in_epoch: bool = True, -) -> (Union[float, u.Quantity, time.Time], np.ndarray): +) -> Tuple[Union[np.ndarray, u.Quantity, time.Time], np.ndarray]: freqs = np.atleast_1d(freqs) assert ( len(freqs.shape) == 1 and len(freqs) <= ntoas From 818f7cfcc50fc33a0c9dd001211b2e6c2918a874 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Tue, 27 Feb 2024 20:10:58 +0000 Subject: [PATCH 16/28] fix syntax for 3.8 --- src/pint/logging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pint/logging.py b/src/pint/logging.py index 92827e192..f1d367f8c 100644 --- a/src/pint/logging.py +++ b/src/pint/logging.py @@ -53,7 +53,7 @@ import re import sys import warnings -from typing import Dict +from typing import Dict, Tuple from erfa import ErfaWarning from loguru import logger as log @@ -73,7 +73,7 @@ # https://loguru.readthedocs.io/en/stable/api/logger.html#color showwarning_ = warnings.showwarning -warning_onceregistry: Dict[tuple[str, str], int] = {} +warning_onceregistry: Dict[Tuple[str, str], int] = {} # basic loguru level definitions from: # https://loguru.readthedocs.io/en/stable/api/logger.html From f3fe13e8b50cbd59761ca3d614f2bfd2f6d1c0f7 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Tue, 27 Feb 2024 22:23:31 +0000 Subject: [PATCH 17/28] Add stricter checking --- mypy.ini | 11 ++++++++++- src/pint/simulation.py | 4 ++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mypy.ini b/mypy.ini index f90950d61..131ac95cf 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,6 +2,11 @@ [mypy] warn_unused_configs = True files = src/pint +# 3.8 causes a problem with some versions of matplotlib +python_version = 3.9 +warn_unreachable = True +warn_return_any = True +local_partial_types = True [mypy-pint.templates.*] ignore_errors = True @@ -39,6 +44,9 @@ ignore_errors = True [mypy-pint.pintk.plk] ignore_errors = True +[mypy-pint.extern.*] +ignore_errors = True + [mypy-astropy.*] ignore_missing_imports = True @@ -61,4 +69,5 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-uncertainties] -ignore_missing_imports = True \ No newline at end of file +ignore_missing_imports = True + diff --git a/src/pint/simulation.py b/src/pint/simulation.py index 96ffa97f2..02ad66e13 100644 --- a/src/pint/simulation.py +++ b/src/pint/simulation.py @@ -51,7 +51,7 @@ def zero_residuals( 1 nanosecond if operating in full precision or 5 us if not. """ ts.compute_pulse_numbers(model) - maxresid = None + maxresid: Optional[float] = None if tolerance is None: tolerance = 1 * u.ns if pint.utils.check_longdouble_precision() else 5 * u.us for i in range(maxiter): @@ -566,7 +566,7 @@ def calculate_random_models( keep_models: bool = True, return_time: bool = False, params: str = "all", -) -> Tuple[np.ndarray, Optional[list]]: +) -> Union[Tuple[np.ndarray, Optional[list]], np.ndarray]: """ Calculates random models based on the covariance matrix of the `fitter` object. From 40865341c901c09fcd63b0f5710e4e3883653c5b Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Wed, 28 Feb 2024 18:18:41 +0000 Subject: [PATCH 18/28] additional restrictions --- mypy.ini | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mypy.ini b/mypy.ini index 131ac95cf..a151f237d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -7,6 +7,8 @@ python_version = 3.9 warn_unreachable = True warn_return_any = True local_partial_types = True +no_implicit_reexport = True +strict_equality = True [mypy-pint.templates.*] ignore_errors = True From c66739619fc25278ef1ab5f027d9b39a394acce4 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Wed, 28 Feb 2024 18:45:59 +0000 Subject: [PATCH 19/28] use overload to specify more closely --- src/pint/simulation.py | 78 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 74 insertions(+), 4 deletions(-) diff --git a/src/pint/simulation.py b/src/pint/simulation.py index 02ad66e13..b8be080eb 100644 --- a/src/pint/simulation.py +++ b/src/pint/simulation.py @@ -4,7 +4,7 @@ import pathlib from collections import OrderedDict from copy import deepcopy -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict, overload import astropy.units as u import numpy as np @@ -33,7 +33,7 @@ def zero_residuals( subtract_mean: bool = True, maxiter: int = 10, tolerance: Optional[u.Quantity] = None, -): +) -> None: """Use a model to adjust a TOAs object, setting residuals to 0 iteratively. Parameters @@ -77,7 +77,7 @@ def get_fake_toa_clock_versions( model: pint.models.timing_model.TimingModel, include_bipm: bool = False, include_gps: bool = True, -) -> dict: +) -> Dict[str, Union[bool, str]]: """Get the clock settings (corrections, etc) for fake TOAs Parameters @@ -220,6 +220,54 @@ def update_fake_dms( return toas +@overload +def make_fake_toas_uniform( + startMJD: float, + endMJD: float, + ntoas: int, + model: pint.models.timing_model.TimingModel, + fuzz: u.Quantity = 0, + freq: u.Quantity = 1400 * u.MHz, + obs: str = "GBT", + error: u.Quantity = 1 * u.us, + add_noise: bool = False, + add_correlated_noise: bool = False, + wideband: bool = False, + wideband_dm_error: u.Quantity = 1e-4 * pint.dmu, + name: str = "fake", + include_bipm: bool = False, + include_gps: bool = True, + multi_freqs_in_epoch: bool = False, + flags: Optional[dict] = None, + subtract_mean: bool = True, +) -> pint.toa.TOAs: + ... + + +@overload +def make_fake_toas_uniform( + startMJD: u.Quantity, + endMJD: u.Quantity, + ntoas: int, + model: pint.models.timing_model.TimingModel, + fuzz: u.Quantity = 0, + freq: u.Quantity = 1400 * u.MHz, + obs: str = "GBT", + error: u.Quantity = 1 * u.us, + add_noise: bool = False, + add_correlated_noise: bool = False, + wideband: bool = False, + wideband_dm_error: u.Quantity = 1e-4 * pint.dmu, + name: str = "fake", + include_bipm: bool = False, + include_gps: bool = True, + multi_freqs_in_epoch: bool = False, + flags: Optional[dict] = None, + subtract_mean: bool = True, +) -> pint.toa.TOAs: + ... + + def make_fake_toas_uniform( startMJD: Union[float, u.Quantity, time.Time], endMJD: Union[float, u.Quantity, time.Time], @@ -566,7 +614,7 @@ def calculate_random_models( keep_models: bool = True, return_time: bool = False, params: str = "all", -) -> Union[Tuple[np.ndarray, Optional[list]], np.ndarray]: +) -> Union[Tuple[np.ndarray, list], np.ndarray]: """ Calculates random models based on the covariance matrix of the `fitter` object. @@ -689,6 +737,28 @@ def calculate_random_models( return (dphase, random_models) if keep_models else dphase +@overload +def _get_freqs_and_times( + start: float, + end: float, + ntoas: int, + freqs: u.Quantity, + multi_freqs_in_epoch: bool = True, +) -> Tuple[np.ndarray, np.ndarray]: + ... + + +@overload +def _get_freqs_and_times( + start: time.Time, + end: time.Time, + ntoas: int, + freqs: u.Quantity, + multi_freqs_in_epoch: bool = True, +) -> Tuple[time.Time, np.ndarray]: + ... + + def _get_freqs_and_times( start: Union[float, u.Quantity, time.Time], end: Union[float, u.Quantity, time.Time], From f71cb193f98dfa6e1a0d19f0ce4673aa7358c4ee Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Wed, 28 Feb 2024 19:04:39 +0000 Subject: [PATCH 20/28] Make mypy a little stricter on observatory --- mypy.ini | 5 +++-- src/pint/observatory/__init__.py | 9 ++++++--- src/pint/observatory/topo_obs.py | 5 ++++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/mypy.ini b/mypy.ini index a151f237d..5ff5f72c3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,7 +17,9 @@ ignore_errors = True ignore_errors = True [mypy-pint.observatory.*] -ignore_errors = True +allow_untyped_globals = True +warn_unreachable = False +warn_return_any = False [mypy-pint.models.stand_alone_psr_binaries.*] ignore_errors = True @@ -72,4 +74,3 @@ ignore_missing_imports = True [mypy-uncertainties] ignore_missing_imports = True - diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py index bccb40d53..8564e2a38 100644 --- a/src/pint/observatory/__init__.py +++ b/src/pint/observatory/__init__.py @@ -28,7 +28,7 @@ from copy import deepcopy from io import StringIO from pathlib import Path -from typing import Optional, Union, List, Dict, Literal +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union import astropy.coordinates import astropy.time @@ -39,7 +39,10 @@ from pint.config import runtimefile from pint.pulsar_mjd import Time -from pint.utils import interesting_lines, PosVel +from pint.utils import PosVel, interesting_lines + +if TYPE_CHECKING: + from pint.observatory.clock_file import ClockFile # Include any files that define observatories here. This will start # with the standard distribution files, then will read any system- or @@ -830,7 +833,7 @@ def update_clock_files(bipm_versions: Optional[List[str]] = None) -> None: # Both topo_obs and special_locations need this def find_clock_file( name: str, - format: Literal["tempo", "tempo2"], + format: str, bogus_last_correction: bool = False, url_base: Optional[str] = None, clock_dir: Union[str, Path, None] = None, diff --git a/src/pint/observatory/topo_obs.py b/src/pint/observatory/topo_obs.py index 3e47d9b5b..7af88807d 100644 --- a/src/pint/observatory/topo_obs.py +++ b/src/pint/observatory/topo_obs.py @@ -359,7 +359,10 @@ def _clock(self) -> list: cf, format=self.clock_fmt, clock_dir=self.clock_dir, - **kwargs, + # mypy is unhappy about passing in a dict as **kwargs + # which is fair enough since it can't check the keys + # are valid arguments. + **kwargs, # type: ignore ) ) return clock From 29743b2bb43f779c8c90da9da2cd919e7344c6f7 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Wed, 28 Feb 2024 19:25:34 +0000 Subject: [PATCH 21/28] get timing_models passing --- mypy.ini | 4 +-- src/pint/models/timing_model.py | 60 ++++++++++++++++----------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/mypy.ini b/mypy.ini index 5ff5f72c3..3c7f64245 100644 --- a/mypy.ini +++ b/mypy.ini @@ -25,7 +25,7 @@ warn_return_any = False ignore_errors = True [mypy-pint.models.timing_model] -ignore_errors = True +; ignore_errors = True [mypy-pint.models.pulsar_binary] ignore_errors = True @@ -72,5 +72,5 @@ ignore_missing_imports = True [mypy-scipy.*] ignore_missing_imports = True -[mypy-uncertainties] +[mypy-uncertainties.*] ignore_missing_imports = True diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index 06e3e1ead..7b8807624 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -24,53 +24,53 @@ :func:`~pint.models.model_builder.get_model`. See :ref:`Timing Models` for more details on how PINT's timing models work. - """ import abc +import contextlib import copy import inspect -import contextlib from collections import OrderedDict, defaultdict from functools import wraps +from typing import Dict from warnings import warn -from uncertainties import ufloat +import astropy.coordinates as coords import astropy.time as time -from astropy import units as u, constants as c import numpy as np +from astropy import constants as c +from astropy import units as u from astropy.utils.decorators import lazyproperty -import astropy.coordinates as coords -from pint.pulsar_ecliptic import OBL, PulsarEcliptic -from scipy.optimize import brentq from loguru import logger as log +from scipy.optimize import brentq +from uncertainties import ufloat import pint +from pint.derived_quantities import dispersion_slope from pint.models.parameter import ( - _parfile_formats, AngleParameter, MJDParameter, Parameter, + _parfile_formats, boolParameter, floatParameter, funcParameter, intParameter, maskParameter, - strParameter, prefixParameter, + strParameter, ) from pint.phase import Phase +from pint.pulsar_ecliptic import OBL, PulsarEcliptic from pint.toa import TOAs from pint.utils import ( PrefixError, - split_prefixed_name, - open_or_use, colorize, + open_or_use, + split_prefixed_name, xxxselections, ) -from pint.derived_quantities import dispersion_slope - __all__ = [ "DEFAULT_ORDER", @@ -490,10 +490,10 @@ def num_components_of_type(type): ), "Model can have at most one solar wind dispersion component." from pint.models.dispersion_model import DispersionDMX + from pint.models.dmwavex import DMWaveX + from pint.models.noise_model import PLDMNoise, PLRedNoise from pint.models.wave import Wave from pint.models.wavex import WaveX - from pint.models.dmwavex import DMWaveX - from pint.models.noise_model import PLRedNoise, PLDMNoise if num_components_of_type((DispersionDMX, PLDMNoise, DMWaveX)) > 1: log.warning( @@ -612,7 +612,9 @@ def free_params(self): """ return [p for p in self.params if not getattr(self, p).frozen] - @free_params.setter + # mypy doesn't understand the decorator syntax here + # maybe we'd need to express the type of property_exists better? + @free_params.setter # type: ignore def free_params(self, params): params_true = {self.match_param_aliases(p) for p in params} for p in self.params: @@ -635,10 +637,8 @@ def fittable_params(self): p in self.phase_deriv_funcs or p in self.delay_deriv_funcs or ( - ( - hasattr(self, "toasigma_deriv_funcs") - and p in self.toasigma_deriv_funcs - ) + hasattr(self, "toasigma_deriv_funcs") + and p in self.toasigma_deriv_funcs ) or (hasattr(self[p], "prefix") and self[p].prefix == "ECORR") ) @@ -3127,7 +3127,9 @@ def get_derived_params(self, rms=None, ntoas=None, returndict=False): ) s += "Conversion from ELL1 parameters:\n" ecc = um.sqrt(eps1**2 + eps2**2) - s += "ECC = {:P}\n".format(ecc) + # mypy does not know about uncertainties introducing a new + # format code, so we have to tell it to ignore this line + s += "ECC = {:P}\n".format(ecc) # type: ignore outdict["ECC"] = ecc om = um.atan2(eps1, eps2) * 180.0 / np.pi if om < 0.0: @@ -3181,14 +3183,12 @@ def get_derived_params(self, rms=None, ntoas=None, returndict=False): omdot = self.OMDOT.as_ufloat(u.rad / u.s) e = ecc if ell1 else self.ECC.as_ufloat() mt = ( - ( - omdot - / ( - 3 - * (c.G * u.Msun / c.c**3).to_value(u.s) ** (2.0 / 3) - * ((pb * 86400 / 2 / np.pi)) ** (-5.0 / 3) - * (1 - e**2) ** -1 - ) + omdot + / ( + 3 + * (c.G * u.Msun / c.c**3).to_value(u.s) ** (2.0 / 3) + * (pb * 86400 / 2 / np.pi) ** (-5.0 / 3) + * (1 - e**2) ** -1 ) ) ** (3.0 / 2) s += f"Total mass, assuming GR, from OMDOT is {mt:SP} Msun\n" @@ -3254,7 +3254,7 @@ class Component(metaclass=ModelMeta): invalid parameter values are chosen. """ - component_types = {} + component_types: Dict[str, ModelMeta] = {} def __init__(self): self.params = [] From 1befd58046e27ba3591022bf5b3d4816efe560da Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Wed, 28 Feb 2024 19:33:23 +0000 Subject: [PATCH 22/28] Get fitter working --- mypy.ini | 5 +-- src/pint/fitter.py | 70 +++++-------------------------------- src/pint/models/__init__.py | 39 +++++++++++++++++++++ 3 files changed, 48 insertions(+), 66 deletions(-) diff --git a/mypy.ini b/mypy.ini index 3c7f64245..e3cf3c787 100644 --- a/mypy.ini +++ b/mypy.ini @@ -24,14 +24,11 @@ warn_return_any = False [mypy-pint.models.stand_alone_psr_binaries.*] ignore_errors = True -[mypy-pint.models.timing_model] -; ignore_errors = True - [mypy-pint.models.pulsar_binary] ignore_errors = True [mypy-pint.fitter] -ignore_errors = True +; ignore_errors = True [mypy-pint.polycos] ignore_errors = True diff --git a/src/pint/fitter.py b/src/pint/fitter.py index 16430d438..9021bb800 100644 --- a/src/pint/fitter.py +++ b/src/pint/fitter.py @@ -60,6 +60,7 @@ import contextlib import copy +from functools import cached_property from typing import List, Optional from warnings import warn @@ -72,8 +73,9 @@ import pint import pint.derived_quantities -import pint.utils +import pint.models import pint.models.timing_model +import pint.utils from pint.models.parameter import ( AngleParameter, boolParameter, @@ -108,64 +110,6 @@ "MaxiterReached", ] -try: - from functools import cached_property -except ImportError: - # not supported in python 3.7 - # This is just the code from python 3.8 - from _thread import RLock - - _NOT_FOUND = object() - - class cached_property: - def __init__(self, func): - self.func = func - self.attrname = None - self.__doc__ = func.__doc__ - self.lock = RLock() - - def __set_name__(self, owner, name): - if self.attrname is None: - self.attrname = name - elif name != self.attrname: - raise TypeError( - "Cannot assign the same cached_property to two different names " - f"({self.attrname!r} and {name!r})." - ) - - def __get__(self, instance, owner=None): - if instance is None: - return self - if self.attrname is None: - raise TypeError( - "Cannot use cached_property instance without calling __set_name__ on it." - ) - try: - cache = instance.__dict__ - except AttributeError: - # not all objects have __dict__ (e.g. class defines slots) - msg = ( - f"No '__dict__' attribute on {type(instance).__name__!r} " - f"instance to cache {self.attrname!r} property." - ) - raise TypeError(msg) from None - val = cache.get(self.attrname, _NOT_FOUND) - if val is _NOT_FOUND: - with self.lock: - # check if another thread filled cache while we awaited lock - val = cache.get(self.attrname, _NOT_FOUND) - if val is _NOT_FOUND: - val = self.func(instance) - try: - cache[self.attrname] = val - except TypeError: - msg = ( - f"The '__dict__' attribute on {type(instance).__name__!r} instance " - f"does not support item assignment for caching {self.attrname!r} property." - ) - raise TypeError(msg) from None - return val - class DegeneracyWarning(UserWarning): pass @@ -1119,17 +1063,19 @@ def _fit_toas( self.parameter_covariance_matrix.to_correlation_matrix() ) - for p, e in zip(self.current_state.params, self.errors): + for p, error in zip(self.current_state.params, self.errors): try: # I don't know why this fails with multiprocessing, but bypass if it does with contextlib.suppress(ValueError): - log.trace(f"Setting {getattr(self.model, p)} uncertainty to {e}") + log.trace( + f"Setting {getattr(self.model, p)} uncertainty to {error}" + ) pm = getattr(self.model, p) except AttributeError: if p != "Offset": log.warning(f"Unexpected parameter {p}") else: - pm.uncertainty = e * pm.units + pm.uncertainty = error * pm.units self.update_model(self.current_state.chi2) if exception is not None: raise StepProblem( diff --git a/src/pint/models/__init__.py b/src/pint/models/__init__.py index e08e69965..f55d4543c 100644 --- a/src/pint/models/__init__.py +++ b/src/pint/models/__init__.py @@ -45,6 +45,45 @@ from pint.models.wave import Wave from pint.models.wavex import WaveX +__all__ = [ + "AbsPhase", + "AstrometryEcliptic", + "AstrometryEquatorial", + "BinaryBT", + "BinaryBTPiecewise", + "BinaryDD", + "BinaryDDS", + "BinaryDDGR", + "BinaryDDK", + "BinaryELL1", + "BinaryELL1H", + "BinaryELL1k", + "DelayJump", + "DispersionDM", + "DispersionDMX", + "DMWaveX", + "EcorrNoise", + "FD", + "FDJump", + "Glitch", + "IFunc", + "PhaseJump", + "PiecewiseSpindown", + "PLRedNoise", + "ScaleToaError", + "SolarSystemShapiro", + "SolarWindDispersion", + "SolarWindDispersionX", + "Spindown", + "TroposphereDelay", + "Wave", + "WaveX", + "get_model", + "get_model_and_toas", + "TimingModel", + "DEFAULT_ORDER", +] + # Define a standard basic model StandardTimingModel = TimingModel( "StandardTimingModel", From d37bd123143312b704dfd9e1238db6983214bce4 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Wed, 28 Feb 2024 19:33:38 +0000 Subject: [PATCH 23/28] Get fitter working --- mypy.ini | 3 --- 1 file changed, 3 deletions(-) diff --git a/mypy.ini b/mypy.ini index e3cf3c787..0b149e251 100644 --- a/mypy.ini +++ b/mypy.ini @@ -27,9 +27,6 @@ ignore_errors = True [mypy-pint.models.pulsar_binary] ignore_errors = True -[mypy-pint.fitter] -; ignore_errors = True - [mypy-pint.polycos] ignore_errors = True From e34d35cedb5362036a08b9993a8dd35600d65dca Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Sun, 3 Mar 2024 15:56:07 +0000 Subject: [PATCH 24/28] Cleaned up all ignore_errors --- mypy.ini | 40 ++---- src/pint/derived_quantities.py | 131 +++++++++--------- src/pint/gridutils.py | 3 +- .../stand_alone_psr_binaries/DDGR_model.py | 13 +- .../stand_alone_psr_binaries/DDH_model.py | 5 +- src/pint/output/publish.py | 19 ++- src/pint/pintk/plk.py | 9 +- src/pint/polycos.py | 17 ++- src/pint/templates/lceprimitives.py | 7 +- src/pint/templates/lcprimitives.py | 1 + src/pint/templates/lctemplate.py | 8 -- 11 files changed, 124 insertions(+), 129 deletions(-) diff --git a/mypy.ini b/mypy.ini index 0b149e251..34e80cf34 100644 --- a/mypy.ini +++ b/mypy.ini @@ -10,41 +10,18 @@ local_partial_types = True no_implicit_reexport = True strict_equality = True -[mypy-pint.templates.*] -ignore_errors = True - -[mypy-pint.derived_quantities] -ignore_errors = True - [mypy-pint.observatory.*] allow_untyped_globals = True warn_unreachable = False warn_return_any = False -[mypy-pint.models.stand_alone_psr_binaries.*] -ignore_errors = True - -[mypy-pint.models.pulsar_binary] -ignore_errors = True - -[mypy-pint.polycos] -ignore_errors = True - -[mypy-pint.output.publish] -ignore_errors = True - -[mypy-pint.gridutils] -ignore_errors = True - -[mypy-pint.scripts.*] -ignore_errors = True - -[mypy-pint.pintk.plk] -ignore_errors = True - [mypy-pint.extern.*] +; external code, don't worry about it ignore_errors = True +; Other libraries that might not have type information +; some of them seem like they should? maybe we need new versions? + [mypy-astropy.*] ignore_missing_imports = True @@ -68,3 +45,12 @@ ignore_missing_imports = True [mypy-uncertainties.*] ignore_missing_imports = True + +[mypy-fftfit] +ignore_missing_imports = True + +[mypy-pathos.*] +ignore_missing_imports = True + +[mypy-corner] +ignore_missing_imports = True \ No newline at end of file diff --git a/src/pint/derived_quantities.py b/src/pint/derived_quantities.py index 0d3646078..87ff423fe 100644 --- a/src/pint/derived_quantities.py +++ b/src/pint/derived_quantities.py @@ -1,5 +1,4 @@ -"""Functions to compute various derived quantities from pulsar spin parameters, masses, etc. -""" +"""Functions to compute various derived quantities from pulsar spin parameters, masses, etc.""" import astropy.constants as const import astropy.units as u @@ -128,7 +127,7 @@ def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): @u.quantity_input(fo=u.Hz) -def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz): +def pulsar_age(f: u.Quantity[u.Hz], fdot: u.Quantity[u.Hz / u.s], n=3, fo=1e99 * u.Hz): r"""Compute pulsar characteristic age Return the age of a pulsar given the spin frequency @@ -171,7 +170,9 @@ def pulsar_age(f: u.Hz, fdot: u.Hz / u.s, n=3, fo=1e99 * u.Hz): @u.quantity_input(I=u.g * u.cm**2) -def pulsar_edot(f: u.Hz, fdot: u.Hz / u.s, I=1.0e45 * u.g * u.cm**2): +def pulsar_edot( + f: u.Quantity[u.Hz], fdot: u.Quantity[u.Hz / u.s], I=1.0e45 * u.g * u.cm**2 +): r"""Compute pulsar spindown energy loss rate Return the pulsar `Edot` (:math:`\dot E`, in erg/s) given the spin frequency `f` and @@ -207,7 +208,7 @@ def pulsar_edot(f: u.Hz, fdot: u.Hz / u.s, I=1.0e45 * u.g * u.cm**2): @u.quantity_input -def pulsar_B(f: u.Hz, fdot: u.Hz / u.s): +def pulsar_B(f: u.Quantity[u.Hz], fdot: u.Quantity[u.Hz / u.s]): r"""Compute pulsar surface magnetic field Return the estimated pulsar surface magnetic field strength @@ -242,7 +243,7 @@ def pulsar_B(f: u.Hz, fdot: u.Hz / u.s): @u.quantity_input -def pulsar_B_lightcyl(f: u.Hz, fdot: u.Hz / u.s): +def pulsar_B_lightcyl(f: u.Quantity[u.Hz], fdot: u.Quantity[u.Hz / u.s]): r"""Compute pulsar magnetic field at the light cylinder Return the estimated pulsar magnetic field strength at the @@ -373,59 +374,61 @@ def mass_funct2(mp: u.Msun, mc: u.Msun, i: u.deg): def pulsar_mass(pb: u.d, x: u.cm, mc: u.Msun, i: u.deg): r"""Compute pulsar mass from orbital parameters - Return the pulsar mass (in solar mass units) for a binary. - Can handle scalar or array inputs. - - Parameters - ---------- - pb : astropy.units.Quantity - Binary orbital period - x : astropy.units.Quantity - Projected pulsar semi-major axis (aka ASINI) in ``pint.ls`` - mc : astropy.units.Quantity - Companion mass in ``u.solMass`` - i : astropy.coordinates.Angle or astropy.units.Quantity - Inclination angle, in ``u.deg`` or ``u.rad`` - - Returns - ------- - mass : astropy.units.Quantity - In ``u.solMass`` - - Raises - ------ - astropy.units.UnitsError - If the input data are not appropriate quantities - TypeError - If the input data are not quantities - - Example - ------- - >>> import pint - >>> import pint.derived_quantities - >>> from astropy import units as u - >>> print(pint.derived_quantities.pulsar_mass(2*u.hr, .2*pint.ls, 0.5*u.Msun, 60*u.deg)) - 7.6018341985817885 solMass - - - Notes - ------- - This forms a quadratic equation of the form: - :math:`a M_p^2 + b M_p + c = 0`` - - with: - - - :math:`a = f(P_b,x)` (the mass function) - - :math:`b = 2 f(P_b,x) M_c` - - :math:`c = f(P_b,x) M_c^2 - M_c\sin^3 i` - - except the discriminant simplifies to: - :math:`4f(P_b,x) M_c^3 \sin^3 i` - - solve it directly - this has to be the positive branch of the quadratic - because the vertex is at :math:`-M_c`, so - the negative branch will always be < 0 + Return the pulsar mass (in solar mass units) for a binary. + Can handle scalar or array inputs. + + Parameters + ---------- + pb : astropy.units.Quantity + Binary orbital period + x : astropy.units.Quantity + Projected pulsar semi-major axis (aka ASINI) in ``pint.ls`` + mc : astropy.units.Quantit[mypy-pint.templates.*] + ; ignore_errors = True + y + Companion mass in ``u.solMass`` + i : astropy.coordinates.Angle or astropy.units.Quantity + Inclination angle, in ``u.deg`` or ``u.rad`` + + Returns + ------- + mass : astropy.units.Quantity + In ``u.solMass`` + + Raises + ------ + astropy.units.UnitsError + If the input data are not appropriate quantities + TypeError + If the input data are not quantities + + Example + ------- + >>> import pint + >>> import pint.derived_quantities + >>> from astropy import units as u + >>> print(pint.derived_quantities.pulsar_mass(2*u.hr, .2*pint.ls, 0.5*u.Msun, 60*u.deg)) + 7.6018341985817885 solMass + + + Notes + ------- + This forms a quadratic equation of the form: + :math:`a M_p^2 + b M_p + c = 0`` + + with: + + - :math:`a = f(P_b,x)` (the mass function) + - :math:`b = 2 f(P_b,x) M_c` + - :math:`c = f(P_b,x) M_c^2 - M_c\sin^3 i` + + except the discriminant simplifies to: + :math:`4f(P_b,x) M_c^3 \sin^3 i` + + solve it directly + this has to be the positive branch of the quadratic + because the vertex is at :math:`-M_c`, so + the negative branch will always be < 0 """ massfunct = mass_funct(pb, x) @@ -869,7 +872,11 @@ def dth(mp: u.Msun, mc: u.Msun, pb: u.d): @u.quantity_input -def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled): +def omdot_to_mtot( + omdot: u.Quantity[u.deg / u.yr], + pb: u.Quantity[u.d], + e: u.Quantity[u.dimensionless_unscaled], +): r"""Determine total mass from Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. @@ -981,7 +988,7 @@ def a1sini(mp, mc, pb, i=90 * u.deg): @u.quantity_input -def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc): +def shklovskii_factor(pmtot: u.Quantity[u.mas / u.yr], D: u.Quantity[u.kpc]): r"""Compute magnitude of Shklovskii correction factor. Computes the Shklovskii correction factor, as defined in Eq 8.12 of Lorimer & Kramer (2005) [10]_ @@ -1019,7 +1026,7 @@ def shklovskii_factor(pmtot: u.mas / u.yr, D: u.kpc): @u.quantity_input -def dispersion_slope(dm: pint.dmu): +def dispersion_slope(dm: u.Quantity[pint.dmu]): """Compute the dispersion slope. This is equal to DMconst * DM. diff --git a/src/pint/gridutils.py b/src/pint/gridutils.py index 4e2ec06d7..521793b8c 100644 --- a/src/pint/gridutils.py +++ b/src/pint/gridutils.py @@ -1,4 +1,5 @@ """Tools for building chi-squared grids.""" + import concurrent.futures import copy import multiprocessing @@ -12,7 +13,7 @@ try: from tqdm import tqdm except ModuleNotFoundError: - tqdm = None + tqdm = None # type: ignore from astropy.utils.console import ProgressBar diff --git a/src/pint/models/stand_alone_psr_binaries/DDGR_model.py b/src/pint/models/stand_alone_psr_binaries/DDGR_model.py index 85c40fd39..86e5e8825 100644 --- a/src/pint/models/stand_alone_psr_binaries/DDGR_model.py +++ b/src/pint/models/stand_alone_psr_binaries/DDGR_model.py @@ -1,4 +1,5 @@ """The DDGR model - Damour and Deruelle with GR assumed""" + import astropy.constants as c import astropy.units as u import numpy as np @@ -618,29 +619,29 @@ def d_beta_d_M2(self): * self.d_omega_d_M2() ) - @SINI.setter + @SINI.setter # type: ignore[no-redef, attr-defined] def SINI(self, val): log.debug( "DDGR model uses MTOT to derive the inclination angle. SINI will not be used." ) - @PBDOT.setter + @PBDOT.setter # type: ignore[no-redef, attr-defined] def PBDOT(self, val): log.debug("DDGR model uses MTOT to derive PBDOT. PBDOT will not be used.") - @OMDOT.setter + @OMDOT.setter # type: ignore[no-redef, attr-defined] def OMDOT(self, val): log.debug("DDGR model uses MTOT to derive OMDOT. OMDOT will not be used.") - @GAMMA.setter + @GAMMA.setter # type: ignore[no-redef, attr-defined] def GAMMA(self, val): log.debug("DDGR model uses MTOT to derive GAMMA. GAMMA will not be used.") - @DR.setter + @DR.setter # type: ignore[no-redef, attr-defined] def DR(self, val): log.debug("DDGR model uses MTOT to derive Dr. Dr will not be used.") - @DTH.setter + @DTH.setter # type: ignore[no-redef, attr-defined] def DTH(self, val): log.debug("DDGR model uses MTOT to derive Dth. Dth will not be used.") diff --git a/src/pint/models/stand_alone_psr_binaries/DDH_model.py b/src/pint/models/stand_alone_psr_binaries/DDH_model.py index 3c4d59dcb..5fe95bcf6 100644 --- a/src/pint/models/stand_alone_psr_binaries/DDH_model.py +++ b/src/pint/models/stand_alone_psr_binaries/DDH_model.py @@ -1,4 +1,5 @@ """The DDS model - Damour and Deruelle with alternate Shapiro delay parametrization.""" + import astropy.constants as c import astropy.units as u import numpy as np @@ -69,13 +70,13 @@ def SINI(self): def M2(self): return self.H3 / self.STIGMA**3 / Tsun.value - @SINI.setter + @SINI.setter # type: ignore[no-redef, attr-defined] def SINI(self, val): log.debug( "DDH model uses H3/STIGMA as Shapiro delay parameter. SINI will not be used." ) - @M2.setter + @M2.setter # type: ignore[no-redef, attr-defined] def M2(self, val): log.debug( "DDH model uses H3/STIGMA as Shapiro delay parameter. M2 will not be used." diff --git a/src/pint/output/publish.py b/src/pint/output/publish.py index e6380eb1f..9832eba6b 100644 --- a/src/pint/output/publish.py +++ b/src/pint/output/publish.py @@ -1,24 +1,28 @@ """Generate LaTeX summary of a timing model and TOAs.""" + +from io import StringIO +from typing import List, Union + +import numpy as np from pint.models import ( - TimingModel, - DispersionDMX, FD, + AbsPhase, + DispersionDMX, Glitch, PhaseJump, SolarWindDispersionX, - AbsPhase, + TimingModel, Wave, ) +from pint.models.timing_model import Component from pint.models.dispersion_model import DispersionJump from pint.models.noise_model import NoiseComponent from pint.models.parameter import ( Parameter, funcParameter, ) -from pint.toa import TOAs from pint.residuals import Residuals, WidebandTOAResiduals -from io import StringIO -import numpy as np +from pint.toa import TOAs def publish_param(param: Parameter): @@ -91,6 +95,7 @@ def publish( else "WLS" ) + res: Union[Residuals, WidebandTOAResiduals] if toas.is_wideband(): res = WidebandTOAResiduals(toas, model) toares = res.toa @@ -117,7 +122,7 @@ def publish( "BINARY", ] - exclude_components = [Wave] + exclude_components: List[type[Component]] = [Wave] if not include_dmx: exclude_components.append(DispersionDMX) if not include_jumps: diff --git a/src/pint/pintk/plk.py b/src/pint/pintk/plk.py index 1ea713b84..2e6c7d6c9 100644 --- a/src/pint/pintk/plk.py +++ b/src/pint/pintk/plk.py @@ -1,6 +1,7 @@ """ Interactive emulator of tempo2 plk """ + import copy import os import sys @@ -28,9 +29,9 @@ try: - from matplotlib.backends.backend_tkagg import NavigationToolbar2Tk + from matplotlib.backends.backend_tkagg import NavigationToolbar2Tk # type: ignore[attr-defined] except ImportError: - from matplotlib.backends.backend_tkagg import ( + from matplotlib.backends.backend_tkagg import ( # type: ignore[no-redef,attr-defined] NavigationToolbar2TkAgg as NavigationToolbar2Tk, ) @@ -587,11 +588,11 @@ class PlkToolbar(NavigationToolbar2Tk): necessary selections/un-selections on points """ - toolitems = [ + toolitems = tuple( t for t in NavigationToolbar2Tk.toolitems if t[0] in ("Home", "Back", "Forward", "Pan", "Zoom", "Save") - ] + ) class PlkActionsWidget(tk.Frame): diff --git a/src/pint/polycos.py b/src/pint/polycos.py index 0542d3304..e1f160fd9 100644 --- a/src/pint/polycos.py +++ b/src/pint/polycos.py @@ -1,7 +1,7 @@ r"""Polynomial coefficients for phase prediction Polycos designed to predict the pulsar's phase and pulse-period over a -given interval using polynomial expansions. +given interval using polynomial expansions. The pulse phase and frequency at time T are then calculated as: @@ -27,25 +27,28 @@ >>> from pint.polycos import Polycos >>> model = get_model(filename) >>> p = Polycos.generate_polycos(model, 50000, 50001, "AO", 144, 12, 1400) - + References ---------- http://tempo.sourceforge.net/ref_man_sections/tz-polyco.txt """ + +from collections import OrderedDict +from collections.abc import Callable +from typing import Dict, List, Union + import astropy.table as table import astropy.units as u import numpy as np from astropy.io import registry from astropy.time import Time -from collections import OrderedDict - from loguru import logger as log try: from tqdm import tqdm -except (ModuleNotFoundError, ImportError) as e: +except (ModuleNotFoundError, ImportError): - def tqdm(*args, **kwargs): + def tqdm(*args, **kwargs): # type: ignore return args[0] if args else kwargs.get("iterable", None) @@ -483,7 +486,7 @@ class Polycos: """ # loaded formats - polycoFormats = [] + polycoFormats: List[Dict[str, Union[str, Callable]]] = [] @classmethod def _register(cls, formatlist=_polycoFormats): diff --git a/src/pint/templates/lceprimitives.py b/src/pint/templates/lceprimitives.py index ad0a1a0d0..93f9b8750 100644 --- a/src/pint/templates/lceprimitives.py +++ b/src/pint/templates/lceprimitives.py @@ -1,10 +1,6 @@ from pint.templates.lcprimitives import * -def isvector(x): - return len(np.asarray(x).shape) > 0 - - def edep_gradient(self, grad_func, phases, log10_ens=3, free=False): """Return the analytic gradient of a general LCEPrimitive. @@ -232,7 +228,8 @@ def _einit(self): self.slope_bounds[2] = [-0.3, 0.3] -class LCELorentzian(LCEWrappedFunction, LCLorentzian): +# LCWrappedFunction.derivative doesn't accept index but LCLorentzian.derivative does +class LCELorentzian(LCEWrappedFunction, LCLorentzian): # type: ignore[misc] """Represent a (wrapped) Lorentzian peak. Parameters diff --git a/src/pint/templates/lcprimitives.py b/src/pint/templates/lcprimitives.py index cbcf665ba..de75583cf 100644 --- a/src/pint/templates/lcprimitives.py +++ b/src/pint/templates/lcprimitives.py @@ -661,6 +661,7 @@ def hessian(self, phases, log10_ens=3, free=False): # results[i,:] += gn[i] return results[self.free, self.free] if free else results + # This derivative doesn't accept an index argument, but LCLorentzian does def derivative(self, phases, log10_ens=3, order=1): """Return the phase gradient (dprim/dphi) at a vector of phases. diff --git a/src/pint/templates/lctemplate.py b/src/pint/templates/lctemplate.py index 8e14f0834..2fcf0fdc3 100644 --- a/src/pint/templates/lctemplate.py +++ b/src/pint/templates/lctemplate.py @@ -20,10 +20,6 @@ log = logging.getLogger(__name__) -def isvector(x): - return len(np.asarray(x).shape) > 0 - - class LCTemplate: """Manage a lightcurve template (collection of LCPrimitive objects). @@ -1071,7 +1067,3 @@ def check_gradient_derivative(templ): for i in range(gd.shape[0]): print(np.max(np.abs(gd[i] - ngd[i]))) return pcs, gd, ngd - - -def isvector(x): - return len(np.asarray(x).shape) > 0 From 843911e3b63c50e135a1c9cf3d9d8233cfccd6ee Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Mon, 4 Mar 2024 18:58:58 +0000 Subject: [PATCH 25/28] Tidy derived_quantities for oldestdeps --- src/pint/derived_quantities.py | 38 ++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/pint/derived_quantities.py b/src/pint/derived_quantities.py index 87ff423fe..c0d072545 100644 --- a/src/pint/derived_quantities.py +++ b/src/pint/derived_quantities.py @@ -126,8 +126,10 @@ def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): return [forp, forperr, fdorpd, fdorpderr] -@u.quantity_input(fo=u.Hz) -def pulsar_age(f: u.Quantity[u.Hz], fdot: u.Quantity[u.Hz / u.s], n=3, fo=1e99 * u.Hz): +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s, fo=u.Hz) +def pulsar_age( + f: u.Quantity, fdot: u.Quantity, n: float = 3, fo: u.Quantity = 1e99 * u.Hz +) -> u.Quantity: r"""Compute pulsar characteristic age Return the age of a pulsar given the spin frequency @@ -169,10 +171,10 @@ def pulsar_age(f: u.Quantity[u.Hz], fdot: u.Quantity[u.Hz / u.s], n=3, fo=1e99 * return (-f / ((n - 1.0) * fdot) * (1.0 - (f / fo) ** (n - 1.0))).to(u.yr) -@u.quantity_input(I=u.g * u.cm**2) +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s, I=u.g * u.cm**2) def pulsar_edot( - f: u.Quantity[u.Hz], fdot: u.Quantity[u.Hz / u.s], I=1.0e45 * u.g * u.cm**2 -): + f: u.Quantity, fdot: u.Quantity, I: u.Quantity = 1.0e45 * u.g * u.cm**2 +) -> u.Quantity: r"""Compute pulsar spindown energy loss rate Return the pulsar `Edot` (:math:`\dot E`, in erg/s) given the spin frequency `f` and @@ -207,8 +209,8 @@ def pulsar_edot( return (-4.0 * np.pi**2 * I * f * fdot).to(u.erg / u.s) -@u.quantity_input -def pulsar_B(f: u.Quantity[u.Hz], fdot: u.Quantity[u.Hz / u.s]): +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s) +def pulsar_B(f: u.Quantity, fdot: u.Quantity) -> u.Quantity: r"""Compute pulsar surface magnetic field Return the estimated pulsar surface magnetic field strength @@ -242,8 +244,8 @@ def pulsar_B(f: u.Quantity[u.Hz], fdot: u.Quantity[u.Hz / u.s]): return 3.2e19 * u.G * np.sqrt(-fdot.to_value(u.Hz / u.s) / f.to_value(u.Hz) ** 3.0) -@u.quantity_input -def pulsar_B_lightcyl(f: u.Quantity[u.Hz], fdot: u.Quantity[u.Hz / u.s]): +@u.quantity_input(f=u.Hz, fdot=u.Hz / u.s) +def pulsar_B_lightcyl(f: u.Quantity, fdot: u.Quantity) -> u.Quantity: r"""Compute pulsar magnetic field at the light cylinder Return the estimated pulsar magnetic field strength at the @@ -871,12 +873,12 @@ def dth(mp: u.Msun, mc: u.Msun, pb: u.d): ).decompose() -@u.quantity_input +@u.quantity_input(omdot=u.deg / u.yr, pb=u.d, e=u.dimensionless_unscaled) def omdot_to_mtot( - omdot: u.Quantity[u.deg / u.yr], - pb: u.Quantity[u.d], - e: u.Quantity[u.dimensionless_unscaled], -): + omdot: u.Quantity, + pb: u.Quantity, + e: u.Quantity, +) -> u.Quantity: r"""Determine total mass from Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. @@ -987,8 +989,8 @@ def a1sini(mp, mc, pb, i=90 * u.deg): ).to(pint.ls) -@u.quantity_input -def shklovskii_factor(pmtot: u.Quantity[u.mas / u.yr], D: u.Quantity[u.kpc]): +@u.quantity_input(pmtot=u.mas / u.yr, d=u.kpc) +def shklovskii_factor(pmtot: u.Quantity, D: u.Quantity) -> u.Quantity: r"""Compute magnitude of Shklovskii correction factor. Computes the Shklovskii correction factor, as defined in Eq 8.12 of Lorimer & Kramer (2005) [10]_ @@ -1025,8 +1027,8 @@ def shklovskii_factor(pmtot: u.Quantity[u.mas / u.yr], D: u.Quantity[u.kpc]): return a_s -@u.quantity_input -def dispersion_slope(dm: u.Quantity[pint.dmu]): +@u.quantity_input(dm=pint.dmu) +def dispersion_slope(dm: u.Quantity) -> u.Quantity: """Compute the dispersion slope. This is equal to DMconst * DM. From 5710867e393a920d025c0aec914ab87a630242d6 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Tue, 5 Mar 2024 18:55:19 +0000 Subject: [PATCH 26/28] return type annotations won't work with oldestdeps --- src/pint/derived_quantities.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pint/derived_quantities.py b/src/pint/derived_quantities.py index c0d072545..a6853fd75 100644 --- a/src/pint/derived_quantities.py +++ b/src/pint/derived_quantities.py @@ -129,7 +129,7 @@ def pferrs(porf, porferr, pdorfd=None, pdorfderr=None): @u.quantity_input(f=u.Hz, fdot=u.Hz / u.s, fo=u.Hz) def pulsar_age( f: u.Quantity, fdot: u.Quantity, n: float = 3, fo: u.Quantity = 1e99 * u.Hz -) -> u.Quantity: +): r"""Compute pulsar characteristic age Return the age of a pulsar given the spin frequency @@ -174,7 +174,7 @@ def pulsar_age( @u.quantity_input(f=u.Hz, fdot=u.Hz / u.s, I=u.g * u.cm**2) def pulsar_edot( f: u.Quantity, fdot: u.Quantity, I: u.Quantity = 1.0e45 * u.g * u.cm**2 -) -> u.Quantity: +): r"""Compute pulsar spindown energy loss rate Return the pulsar `Edot` (:math:`\dot E`, in erg/s) given the spin frequency `f` and @@ -210,7 +210,7 @@ def pulsar_edot( @u.quantity_input(f=u.Hz, fdot=u.Hz / u.s) -def pulsar_B(f: u.Quantity, fdot: u.Quantity) -> u.Quantity: +def pulsar_B(f: u.Quantity, fdot: u.Quantity): r"""Compute pulsar surface magnetic field Return the estimated pulsar surface magnetic field strength @@ -245,7 +245,7 @@ def pulsar_B(f: u.Quantity, fdot: u.Quantity) -> u.Quantity: @u.quantity_input(f=u.Hz, fdot=u.Hz / u.s) -def pulsar_B_lightcyl(f: u.Quantity, fdot: u.Quantity) -> u.Quantity: +def pulsar_B_lightcyl(f: u.Quantity, fdot: u.Quantity): r"""Compute pulsar magnetic field at the light cylinder Return the estimated pulsar magnetic field strength at the @@ -878,7 +878,7 @@ def omdot_to_mtot( omdot: u.Quantity, pb: u.Quantity, e: u.Quantity, -) -> u.Quantity: +): r"""Determine total mass from Post-Keplerian longitude of periastron precession rate omdot, assuming general relativity. @@ -990,7 +990,7 @@ def a1sini(mp, mc, pb, i=90 * u.deg): @u.quantity_input(pmtot=u.mas / u.yr, d=u.kpc) -def shklovskii_factor(pmtot: u.Quantity, D: u.Quantity) -> u.Quantity: +def shklovskii_factor(pmtot: u.Quantity, D: u.Quantity): r"""Compute magnitude of Shklovskii correction factor. Computes the Shklovskii correction factor, as defined in Eq 8.12 of Lorimer & Kramer (2005) [10]_ @@ -1028,7 +1028,7 @@ def shklovskii_factor(pmtot: u.Quantity, D: u.Quantity) -> u.Quantity: @u.quantity_input(dm=pint.dmu) -def dispersion_slope(dm: u.Quantity) -> u.Quantity: +def dispersion_slope(dm: u.Quantity): """Compute the dispersion slope. This is equal to DMconst * DM. From da07be47dc34a0b9ba613006f0624c99b8b5f3ff Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Tue, 5 Mar 2024 20:31:59 +0000 Subject: [PATCH 27/28] Annotating pint.utils --- src/pint/observatory/__init__.py | 4 +- src/pint/utils.py | 208 +++++++++++++++++++------------ 2 files changed, 128 insertions(+), 84 deletions(-) diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py index 8564e2a38..caf3af767 100644 --- a/src/pint/observatory/__init__.py +++ b/src/pint/observatory/__init__.py @@ -577,10 +577,10 @@ def compare_t2_observatories_dat(t2dir: Optional[str] = None) -> Dict[str, List[ with open(filename) as f: for line in interesting_lines(f, comments="#"): try: - x, y, z, full_name, short_name = line.split() + x_str, y_str, z_str, full_name, short_name = line.split() except ValueError as e: raise ValueError(f"unrecognized line '{line}'") from e - x, y, z = float(x), float(y), float(z) + x, y, z = float(x_str), float(y_str), float(z_str) full_name, short_name = full_name.lower(), short_name.lower() topo_obs_entry = textwrap.dedent( f""" diff --git a/src/pint/utils.py b/src/pint/utils.py index 1331ef53b..3b72c175b 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -28,6 +28,7 @@ has moved to :mod:`pint.simulation`. """ + import configparser import datetime import getpass @@ -37,11 +38,13 @@ import re import sys import textwrap +import warnings +from collections.abc import Generator, Iterable from contextlib import contextmanager +from copy import deepcopy from pathlib import Path +from typing import IO, Any, Optional, Tuple, Union, List, Dict, Type, Mapping, cast from warnings import warn -from scipy.optimize import minimize -from numdifftools import Hessian import astropy.constants as const import astropy.coordinates as coords @@ -50,16 +53,15 @@ from astropy import constants from astropy.time import Time from loguru import logger as log -from scipy.special import fdtrc +from numdifftools import Hessian from scipy.linalg import cho_factor, cho_solve -from copy import deepcopy -import warnings +from scipy.optimize import minimize +from scipy.special import fdtrc import pint import pint.pulsar_ecliptic from pint.toa_select import TOASelect - __all__ = [ "PINTPrecisionError", "check_longdouble_precision", @@ -114,8 +116,17 @@ "get_unit", ] -COLOR_NAMES = ["black", "red", "green", "yellow", "blue", "magenta", "cyan", "white"] -TEXT_ATTRIBUTES = [ +COLOR_NAMES: list[str] = [ + "black", + "red", + "green", + "yellow", + "blue", + "magenta", + "cyan", + "white", +] +TEXT_ATTRIBUTES: list[str] = [ "normal", "bold", "subdued", @@ -145,7 +156,7 @@ def check_longdouble_precision(): return np.finfo(np.longdouble).eps < 2e-19 -def require_longdouble_precision(): +def require_longdouble_precision() -> None: """Raise an exception if long doubles do not have enough precision. Raises RuntimeError if PINT cannot be run with high precision on this @@ -181,7 +192,13 @@ class PosVel: """ - def __init__(self, pos, vel, obj=None, origin=None): + def __init__( + self, + pos: Union[u.Quantity, np.ndarray], + vel: Union[u.Quantity, np.ndarray], + obj=None, + origin=None, + ): if len(pos) != 3: raise ValueError(f"Position vector has length {len(pos)} instead of 3") self.pos = pos if isinstance(pos, u.Quantity) else np.asarray(pos) @@ -207,13 +224,13 @@ def __init__(self, pos, vel, obj=None, origin=None): self.origin = origin # FIXME: what about dtype compatibility? - def _has_labels(self): + def _has_labels(self) -> bool: return (self.obj is not None) and (self.origin is not None) - def __neg__(self): + def __neg__(self) -> "PosVel": return PosVel(-self.pos, -self.vel, obj=self.origin, origin=self.obj) - def __add__(self, other): + def __add__(self, other: "PosVel") -> "PosVel": obj = None origin = None if self._has_labels() and other._has_labels(): @@ -234,17 +251,17 @@ def __add__(self, other): self.pos + other.pos, self.vel + other.vel, obj=obj, origin=origin ) - def __sub__(self, other): + def __sub__(self, other: "PosVel") -> "PosVel": return self.__add__(other.__neg__()) - def __str__(self): + def __str__(self) -> str: return ( f"PosVel({str(self.pos)}, {str(self.vel)} {self.origin}->{self.obj})" if self._has_labels() else f"PosVel({str(self.pos)}, {str(self.vel)})" ) - def __getitem__(self, k): + def __getitem__(self, k: Union[int, Tuple[int, ...]]) -> "PosVel": """Allow extraction of slices of the contained arrays""" colon = slice(None, None, None) ix = (colon,) + k if isinstance(k, tuple) else (colon, k) @@ -305,7 +322,7 @@ def check_all_partials(f, args, delta=1e-6, atol=1e-4, rtol=1e-4): raise -def has_astropy_unit(x): +def has_astropy_unit(x) -> bool: """Test whether x has a unit attribute containing an astropy unit. This is useful, because different data types can still have units @@ -328,7 +345,7 @@ class PrefixError(ValueError): pass -def split_prefixed_name(name): +def split_prefixed_name(name: str) -> Tuple[str, str, int]: """Split a prefixed name. Parameters @@ -365,17 +382,16 @@ def split_prefixed_name(name): """ for pt in prefix_pattern: - try: - prefix_part, index_part = pt.match(name).groups() + m = pt.match(name) + if m is not None: + prefix_part, index_part = m.groups() break - except AttributeError: - continue else: raise PrefixError(f"Unrecognized prefix name pattern '{name}'.") return prefix_part, index_part, int(index_part) -def taylor_horner(x, coeffs): +def taylor_horner(x: Union[float, np.ndarray, u.Quantity], coeffs): """Evaluate a Taylor series of coefficients at x via the Horner scheme. For example, if we want: 10 + 3*x/1! + 4*x^2/2! + 12*x^3/3! with @@ -444,7 +460,10 @@ def taylor_horner_deriv(x, coeffs, deriv_order=1): @contextmanager -def open_or_use(f, mode="r"): +def open_or_use( + f: Union[str, bytes, Path, IO[Any]], + mode: str = "r", +) -> Generator[IO[Any], None, None]: """Open a filename or use an open file. Specifically, if f is a string, try to use it as an argument to @@ -459,7 +478,7 @@ def open_or_use(f, mode="r"): yield f -def lines_of(f): +def lines_of(f: Union[str, bytes, Path, IO[str]]) -> Generator[str, None, None]: """Iterate over the lines of a file, an open file, or an iterator. If ``f`` is a string, try to open a file of that name. Otherwise @@ -472,7 +491,10 @@ def lines_of(f): yield from fo -def interesting_lines(lines, comments=None): +def interesting_lines( + lines: Iterable[str], + comments: Union[None, str, Iterable[Union[str]]] = None, +) -> Generator[str, None, None]: """Iterate over lines skipping whitespace and comments. Each line has its whitespace stripped and then it is checked whether @@ -480,6 +502,7 @@ def interesting_lines(lines, comments=None): a list of strings. """ + cc: Tuple[str, ...] if comments is None: cc = () elif isinstance(comments, (str, bytes)): @@ -490,8 +513,8 @@ def interesting_lines(lines, comments=None): cs = c.strip() if not cs or not c.startswith(cs): raise ValueError( - "Unable to deal with comments that start with whitespace, " - "but comment string {!r} was requested.".format(c) + f"Unable to deal with comments that start with whitespace, " + f"but comment string {c:!r} was requested." ) for ln in lines: ln = ln.strip() @@ -1077,7 +1100,7 @@ def dmxparse(fitter, save=False): } -def get_prefix_timerange(model, prefixname): +def get_prefix_timerange(model, prefixname: str) -> Tuple[Time, Time]: """Get time range for a prefix quantity like DMX or SWX Parameters @@ -1105,7 +1128,7 @@ def get_prefix_timerange(model, prefixname): return getattr(model, r1).quantity, getattr(model, r2).quantity -def get_prefix_timeranges(model, prefixname): +def get_prefix_timeranges(model, prefixname: str) -> Tuple[np.ndarray, Time, Time]: """Get all time ranges and indices for a prefix quantity like DMX or SWX Parameters @@ -1142,7 +1165,9 @@ def get_prefix_timeranges(model, prefixname): ) -def find_prefix_bytime(model, prefixname, t): +def find_prefix_bytime( + model, prefixname: str, t: Union[Time, u.Quantity] +) -> Union[int, np.ndarray]: """Identify matching index(es) for a prefix parameter like DMX Parameters @@ -1163,11 +1188,14 @@ def find_prefix_bytime(model, prefixname, t): indices, r1, r2 = get_prefix_timeranges(model, prefixname) matches = np.where((t >= r1) & (t < r2))[0] if len(matches) == 1: - matches = int(matches) - return indices[matches] + return int(indices[int(matches)]) + else: + return indices[matches] -def merge_dmx(model, index1, index2, value="mean", frozen=True): +def merge_dmx( + model, index1: int, index2: int, value: str = "mean", frozen: bool = True +) -> int: """Merge two DMX bins Parameters @@ -1197,7 +1225,7 @@ def merge_dmx(model, index1, index2, value="mean", frozen=True): ) if value.lower() == "first": dmx = getattr(model, f"DMX_{index1:04d}").quantity - elif value.lower == "second": + elif value.lower() == "second": dmx = getattr(model, f"DMX_{index2:04d}").quantity elif value.lower() == "mean": dmx = ( @@ -1205,14 +1233,13 @@ def merge_dmx(model, index1, index2, value="mean", frozen=True): + getattr(model, f"DMX_{index2:04d}").quantity ) / 2 # add the new one before we delete previous ones to make sure we have >=1 present - newindex = model.add_DMX_range(tstart, tend, dmx=dmx, frozen=frozen) + newindex: int = model.add_DMX_range(tstart, tend, dmx=dmx, frozen=frozen) model.remove_DMX_range([index1, index2]) return newindex -def split_dmx(model, time): - """ - Split an existing DMX bin at the desired time +def split_dmx(model, time: Time) -> Tuple[int, int]: + """Split an existing DMX bin at the desired time. Parameters ---------- @@ -1234,10 +1261,10 @@ def split_dmx(model, time): dmx_epochs = [f"{x:04d}" for x in DMX_mapping.keys()] DMX_R1 = np.zeros(len(dmx_epochs)) DMX_R2 = np.zeros(len(dmx_epochs)) - for ii, epoch in enumerate(dmx_epochs): - DMX_R1[ii] = getattr(model, "DMXR1_{:}".format(epoch)).value - DMX_R2[ii] = getattr(model, "DMXR2_{:}".format(epoch)).value - ii = np.where((time.mjd > DMX_R1) & (time.mjd < DMX_R2))[0] + for iii, epoch in enumerate(dmx_epochs): + DMX_R1[iii] = getattr(model, "DMXR1_{:}".format(epoch)).value + DMX_R2[iii] = getattr(model, "DMXR2_{:}".format(epoch)).value + ii: np.ndarray = np.where((time.mjd > DMX_R1) & (time.mjd < DMX_R2))[0] if len(ii) == 0: raise ValueError(f"Time {time} not in any DMX bins") ii = ii[0] @@ -1255,9 +1282,8 @@ def split_dmx(model, time): return index, newindex -def split_swx(model, time): - """ - Split an existing SWX bin at the desired time +def split_swx(model, time: Time) -> Tuple[int, int]: + """Split an existing SWX bin at the desired time. Parameters ---------- @@ -1270,7 +1296,6 @@ def split_swx(model, time): Index of existing bin that was split newindex : int Index of new bin that was added - """ try: SWX_mapping = model.get_prefix_mapping("SWX_") @@ -1279,9 +1304,9 @@ def split_swx(model, time): swx_epochs = [f"{x:04d}" for x in SWX_mapping.keys()] SWX_R1 = np.zeros(len(swx_epochs)) SWX_R2 = np.zeros(len(swx_epochs)) - for ii, epoch in enumerate(swx_epochs): - SWX_R1[ii] = getattr(model, "SWXR1_{:}".format(epoch)).value - SWX_R2[ii] = getattr(model, "SWXR2_{:}".format(epoch)).value + for iii, epoch in enumerate(swx_epochs): + SWX_R1[iii] = getattr(model, "SWXR1_{:}".format(epoch)).value + SWX_R2[iii] = getattr(model, "SWXR2_{:}".format(epoch)).value ii = np.where((time.mjd > SWX_R1) & (time.mjd < SWX_R2))[0] if len(ii) == 0: raise ValueError(f"Time {time} not in any SWX bins") @@ -1301,7 +1326,8 @@ def split_swx(model, time): def wavex_setup(model, T_span, freqs=None, n_freqs=None, freeze_params=False): - """ + """Set up a WaveX model. + Set-up a WaveX model based on either an array of user-provided frequencies or the wave number frequency calculation. Sine and Cosine amplitudes are initially set to zero @@ -1725,7 +1751,13 @@ def translate_wavex_to_wave(model): return new_model -def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): +def weighted_mean( + arrin: np.ndarray, + weights_in: np.ndarray, + inputmean: Optional[float] = None, + calcerr: bool = False, + sdev: bool = False, +) -> Union[Tuple[float, float], Tuple[float, float, float]]: """Compute weighted mean of input values Calculate the weighted mean, error, and optionally standard deviation of @@ -1736,10 +1768,10 @@ def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): Parameters ---------- arrin : array - Array containing the numbers whose weighted mean is desired. + Array containing the numbers whose weighted mean is desired. weights: array - A set of weights for each element in array. For measurements with - uncertainties, these should be 1/sigma^2. + A set of weights for each element in array. For measurements with + uncertainties, these should be 1/sigma^2. inputmean: float, optional An input mean value, around which the mean is calculated. calcerr : bool, optional @@ -1753,8 +1785,8 @@ def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): Returns ------- wmean, werr: tuple - A tuple of the weighted mean and error. If sdev=True the - tuple will also contain sdev: wmean,werr,wsdev + A tuple of the weighted mean and error. If sdev=True the + tuple will also contain sdev: wmean,werr,wsdev Notes ----- @@ -1839,7 +1871,7 @@ def ELL1_check( return False -def FTest(chi2_1, dof_1, chi2_2, dof_2): +def FTest(chi2_1: float, dof_1: int, chi2_2: float, dof_2: int) -> float: """Run F-test. Compute an F-test to see if a model with extra parameters is @@ -1876,7 +1908,7 @@ def FTest(chi2_1, dof_1, chi2_2, dof_2): delta_dof = dof_1 - dof_2 new_redchi2 = chi2_2 / dof_2 F = float((delta_chi2 / delta_dof) / new_redchi2) # fdtr doesn't like float128 - return fdtrc(delta_dof, dof_2, F) + return float(fdtrc(delta_dof, dof_2, F)) elif dof_1 == dof_2: log.warning("Models have equal degrees of freedom, cannot perform F-test.") return np.nan @@ -1887,7 +1919,9 @@ def FTest(chi2_1, dof_1, chi2_2, dof_2): return 1.0 -def add_dummy_distance(c, distance=1 * u.kpc): +def add_dummy_distance( + c: coords.SkyCoord, distance: u.Quantity = 1 * u.kpc +) -> coords.SkyCoord: """Adds a dummy distance to a SkyCoord object for applying proper motion Parameters @@ -1959,7 +1993,7 @@ def add_dummy_distance(c, distance=1 * u.kpc): return c -def remove_dummy_distance(c): +def remove_dummy_distance(c: coords.SkyCoord) -> coords.SkyCoord: """Removes a dummy distance from a SkyCoord object after applying proper motion Parameters @@ -2024,7 +2058,9 @@ def remove_dummy_distance(c): return c -def info_string(prefix_string="# ", comment=None, detailed=False): +def info_string( + prefix_string: str = "# ", comment: Optional[str] = None, detailed: bool = False +) -> str: """Returns an informative string about the current state of PINT. Adds: @@ -2132,7 +2168,7 @@ def info_string(prefix_string="# ", comment=None, detailed=False): # user-level git config c = git.GitConfigParser() - username = c.get_value("user", option="name") + f" ({getpass.getuser()})" + username = str(c.get_value("user", option="name")) + f" ({getpass.getuser()})" except (configparser.NoOptionError, configparser.NoSectionError, ImportError): username = getpass.getuser() @@ -2146,13 +2182,14 @@ def info_string(prefix_string="# ", comment=None, detailed=False): } if detailed: - from numpy import __version__ as numpy_version - from scipy import __version__ as scipy_version from astropy import __version__ as astropy_version from erfa import __version__ as erfa_version from jplephem import __version__ as jpleph_version + from loguru import __version__ as loguru_version # type: ignore[attr-defined] from matplotlib import __version__ as matplotlib_version - from loguru import __version__ as loguru_version + from numpy import __version__ as numpy_version + from scipy import __version__ as scipy_version + from pint import __file__ as pint_file info_dict.update( @@ -2205,7 +2242,7 @@ def info_string(prefix_string="# ", comment=None, detailed=False): return s -def list_parameters(class_=None): +def list_parameters(class_=None) -> List[Dict]: """List parameters understood by PINT. Parameters @@ -2265,7 +2302,7 @@ def list_parameters(class_=None): results = {} ct = pint.models.timing_model.Component.component_types.copy() - ct["TimingModel"] = pint.models.timing_model.TimingModel + ct["TimingModel"] = pint.models.timing_model.TimingModel # type: ignore[assignment] for v in ct.values(): for d in list_parameters(v): n = d["name"] @@ -2284,7 +2321,12 @@ def list_parameters(class_=None): return sorted(results.values(), key=lambda d: d["name"]) -def colorize(text, fg_color=None, bg_color=None, attribute=None): +def colorize( + text: str, + fg_color: Optional[str] = None, + bg_color: Optional[str] = None, + attribute: Optional[str] = None, +) -> str: """Colorizes a string (including unicode strings) for printing on the terminal For an example of usage, as well as a demonstration as to what the @@ -2311,9 +2353,11 @@ def colorize(text, fg_color=None, bg_color=None, attribute=None): The colorized string using the defined text attribute. """ COLOR_FORMAT = "\033[%dm\033[%d;%dm%s\033[0m" - FOREGROUND = dict(zip(COLOR_NAMES, list(range(30, 38)))) - BACKGROUND = dict(zip(COLOR_NAMES, list(range(40, 48)))) - ATTRIBUTE = dict(zip(TEXT_ATTRIBUTES, [0, 1, 2, 3, 4, 5, 7, 8])) + FOREGROUND: Dict[Optional[str], int] = dict(zip(COLOR_NAMES, list(range(30, 38)))) + BACKGROUND: Dict[Optional[str], int] = dict(zip(COLOR_NAMES, list(range(40, 48)))) + ATTRIBUTE: Dict[Optional[str], int] = dict( + zip(TEXT_ATTRIBUTES, [0, 1, 2, 3, 4, 5, 7, 8]) + ) fg = FOREGROUND.get(fg_color, 39) bg = BACKGROUND.get(bg_color, 49) att = ATTRIBUTE.get(attribute, 0) @@ -2332,7 +2376,7 @@ def print_color_examples(): print("") -def group_iterator(items): +def group_iterator(items: np.ndarray) -> Generator[Tuple[Any, np.ndarray], None, None]: """An iterator to step over identical items in a :class:`numpy.ndarray` Example @@ -2349,7 +2393,7 @@ def group_iterator(items): yield item, np.where(items == item)[0] -def compute_hash(filename): +def compute_hash(filename: Union[str, Path, IO[bytes]]) -> bytes: """Compute a unique hash of a file. This is designed to keep around to detect changes, not to be @@ -2378,9 +2422,10 @@ def compute_hash(filename): return h.digest() -def get_conjunction(coord, t0, precision="low", ecl="IERS2010"): - """ - Find first time of Solar conjuction after t0 and approximate elongation at conjunction +def get_conjunction( + coord: coords.SkyCoord, t0: Time, precision: str = "low", ecl: str = "IERS2010" +) -> Tuple[Time, u.Quantity]: + """Find first time of Solar conjuction after t0 and approximate elongation at conjunction. Offers a low-precision version (based on analytic expression of Solar longitude) Or a higher-precision version (based on interpolating :func:`astropy.coordinates.get_sun`) @@ -2445,9 +2490,8 @@ def get_conjunction(coord, t0, precision="low", ecl="IERS2010"): return conjunction, csun.separation(coord) -def divide_times(t, t0, offset=0.5): - """ - Divide input times into years relative to t0 +def divide_times(t: Time, t0: Time, offset: float = 0.5) -> np.ndarray: + """Divide input times into years relative to t0. Years are centered around the requested offset value @@ -2479,7 +2523,7 @@ def divide_times(t, t0, offset=0.5): """ dt = t - t0 values = (dt.to(u.yr).value + offset) // 1 - return np.digitize(values, np.unique(values), right=True) + return cast(np.ndarray, np.digitize(values, np.unique(values), right=True)) def convert_dispersion_measure(dm, dmconst=None): @@ -2735,8 +2779,8 @@ def woodbury_dot(Ndiag, U, Phidiag, x, y): def _get_wx2pl_lnlike(model, component_name, ignore_fyr=True): - from pint.models.noise_model import powerlaw from pint import DMconst + from pint.models.noise_model import powerlaw assert component_name in ["WaveX", "DMWaveX"] prefix = "WX" if component_name == "WaveX" else "DMWX" From 1b3b20f26e50e52b2135f256f8cb263424b77301 Mon Sep 17 00:00:00 2001 From: Anne Archibald Date: Tue, 5 Mar 2024 20:34:29 +0000 Subject: [PATCH 28/28] list -> List --- src/pint/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pint/utils.py b/src/pint/utils.py index 3b72c175b..78a25a5d8 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -116,7 +116,7 @@ "get_unit", ] -COLOR_NAMES: list[str] = [ +COLOR_NAMES: List[str] = [ "black", "red", "green", @@ -126,7 +126,7 @@ "cyan", "white", ] -TEXT_ATTRIBUTES: list[str] = [ +TEXT_ATTRIBUTES: List[str] = [ "normal", "bold", "subdued",