Skip to content

Commit

Permalink
Allow the numba cache to be used, for development
Browse files Browse the repository at this point in the history
Copied from SGKIT. Fixes tskit-dev#438
  • Loading branch information
hyanwong committed Dec 1, 2024
1 parent 2b3412a commit 5f8a8ca
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 74 deletions.
26 changes: 26 additions & 0 deletions tsdate/accelerate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
from typing import Callable

from numba import jit

# By default we disable the numba cache. See
_DISABLE_CACHE = os.environ.get("TSDATE_DISABLE_NUMBA_CACHE", "1")

try:
CACHE_NUMBA = {"0": True, "1": False}[_DISABLE_CACHE]
except KeyError as e: # pragma: no cover
raise KeyError(
"Environment variable 'TSDATE_DISABLE_NUMBA_CACHE' must be '0' or '1'"
) from e


DEFAULT_NUMBA_ARGS = {
"nopython": True,
"cache": CACHE_NUMBA,
}


def numba_jit(*args, **kwargs) -> Callable: # pragma: no cover
kwargs_ = DEFAULT_NUMBA_ARGS.copy()
kwargs_.update(kwargs)
return jit(*args, **kwargs_)
77 changes: 39 additions & 38 deletions tsdate/approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import numpy as np

from . import hypergeo
from .accelerate import numba_jit

# TODO: these are reasonable defaults but could
# be set via a control dict
Expand Down Expand Up @@ -70,7 +71,7 @@ class KLMinimizationFailedError(Exception):
pass


@numba.njit(_unituple(_f, 3)(_f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f))
def approximate_log_moments(mean, variance):
"""
Approximate log moments via a second-order Taylor series expansion around
Expand All @@ -88,7 +89,7 @@ def approximate_log_moments(mean, variance):
return logx, xlogx, logx2


@numba.njit(_unituple(_f, 2)(_f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f))
def approximate_gamma_kl(x, logx):
"""
Use Newton root finding to get gamma natural parameters matching the sufficient
Expand Down Expand Up @@ -126,7 +127,7 @@ def approximate_gamma_kl(x, logx):
return alpha - 1.0, alpha / x


@numba.njit(_unituple(_f, 2)(_f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f))
def approximate_gamma_mom(mean, variance):
"""
Use the method of moments to approximate a distribution with a gamma of the
Expand All @@ -139,7 +140,7 @@ def approximate_gamma_mom(mean, variance):
return shape - 1.0, rate


@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f, _f, _f))
def approximate_gamma_iqr(q1, q2, x1, x2):
"""Find gamma natural parameters that match empirical quantiles"""
if not (q2 > q1 and x2 > x1):
Expand Down Expand Up @@ -177,7 +178,7 @@ def approximate_gamma_iqr(q1, q2, x1, x2):
return alpha - 1, beta


@numba.njit(_unituple(_f, 2)(_f1r, _f1r))
@numba_jit(_unituple(_f, 2)(_f1r, _f1r))
def average_gammas(alpha, beta):
"""
Given natural parameters for a set of gammas, average sufficient
Expand All @@ -195,7 +196,7 @@ def average_gammas(alpha, beta):
return approximate_gamma_kl(avg_x, avg_logx)


@numba.njit(_b(_f, _f))
@numba_jit(_b(_f, _f))
def _valid_moments(mn, va):
if not (np.isfinite(mn) and np.isfinite(va)):
return False
Expand All @@ -204,7 +205,7 @@ def _valid_moments(mn, va):
return True


@numba.njit(_b(_f, _f))
@numba_jit(_b(_f, _f))
def _valid_gamma(s, r):
if not (np.isfinite(s) and np.isfinite(r)):
return False
Expand All @@ -213,7 +214,7 @@ def _valid_gamma(s, r):
return True


@numba.njit(_b(_f, _f, _f))
@numba_jit(_b(_f, _f, _f))
def _valid_hyp1f1(a, b, z):
if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(z)):
return False
Expand All @@ -222,7 +223,7 @@ def _valid_hyp1f1(a, b, z):
return True


@numba.njit(_b(_f, _f, _f))
@numba_jit(_b(_f, _f, _f))
def _valid_hyperu(a, b, z):
if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(z)):
return False
Expand All @@ -233,7 +234,7 @@ def _valid_hyperu(a, b, z):
return True


@numba.njit(_b(_f, _f, _f, _f))
@numba_jit(_b(_f, _f, _f, _f))
def _valid_hyp2f1(a, b, c, z):
if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(c)):
return False
Expand All @@ -247,7 +248,7 @@ def _valid_hyp2f1(a, b, c, z):
# --- various EP updates --- #


@numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f))
def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_i, t_j) := \
Expand Down Expand Up @@ -289,7 +290,7 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
return logl, mn_i, va_i, mn_j, va_j


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
r"""
log p(t_i) := \
Expand Down Expand Up @@ -331,7 +332,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
return logl, mn_i, va_i


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_j) := \
Expand Down Expand Up @@ -365,7 +366,7 @@ def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij):
return logl, mn_j, va_j


@numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f))
def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_i, t_j) := \
Expand Down Expand Up @@ -407,7 +408,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
return logl, mn_i, va_i, mn_j, va_j


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f))
def twin_moments(a_i, b_i, y_ij, mu_ij):
r"""
log p(t_i) := \
Expand All @@ -424,7 +425,7 @@ def twin_moments(a_i, b_i, y_ij, mu_ij):
return logl, mn_i, va_i


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_j) := \
Expand Down Expand Up @@ -454,7 +455,7 @@ def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij):
return logl, mn_j, va_j


@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f))
def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_m, t_i, t_j) = \
Expand Down Expand Up @@ -496,7 +497,7 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
return mn_m, va_m


@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f, _f, _f, _f))
def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
r"""
log p(t_m, t_i) := \
Expand All @@ -515,7 +516,7 @@ def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
return mn_m, va_m


@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f, _f, _f, _f))
def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_m, t_j) := \
Expand All @@ -534,7 +535,7 @@ def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij):
return mn_m, va_m


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f))
def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_m, t_i, t_j) := \
Expand Down Expand Up @@ -587,7 +588,7 @@ def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
return pr_m, mn_m, va_m


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f))
def mutation_twin_moments(a_i, b_i, y_ij, mu_ij):
r"""
log p(t_m, t_i) := \
Expand All @@ -606,7 +607,7 @@ def mutation_twin_moments(a_i, b_i, y_ij, mu_ij):
return pr_m, mn_m, va_m


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_m, t_j) := \
Expand Down Expand Up @@ -655,7 +656,7 @@ def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij):
return pr_m, mn_m, va_m


@numba.njit(_unituple(_f, 2)(_f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f))
def mutation_edge_moments(t_i, t_j):
r"""
log p(t_m) := \
Expand All @@ -670,7 +671,7 @@ def mutation_edge_moments(t_i, t_j):
return mn_m, va_m


@numba.njit(_unituple(_f, 3)(_f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f))
def mutation_block_moments(t_i, t_j):
r"""
log p(t_m) := \
Expand All @@ -694,7 +695,7 @@ def mutation_block_moments(t_i, t_j):
# --- wrappers around updates --- #


@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r))
def gamma_projection(pars_i, pars_j, pars_ij):
r"""
log p(t_i, t_j) := \
Expand All @@ -721,7 +722,7 @@ def gamma_projection(pars_i, pars_j, pars_ij):
return logl, np.array(proj_i), np.array(proj_j)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def leafward_projection(t_i, pars_j, pars_ij):
r"""
log p(t_j) := \
Expand All @@ -744,7 +745,7 @@ def leafward_projection(t_i, pars_j, pars_ij):
return logl, np.array(proj_j)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def rootward_projection(t_j, pars_i, pars_ij):
r"""
log p(t_i) := \
Expand All @@ -767,7 +768,7 @@ def rootward_projection(t_j, pars_i, pars_ij):
return logl, np.array(proj_i)


@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r))
def unphased_projection(pars_i, pars_j, pars_ij):
r"""
log p(t_i, t_j) := \
Expand All @@ -794,7 +795,7 @@ def unphased_projection(pars_i, pars_j, pars_ij):
return logl, np.array(proj_i), np.array(proj_j)


@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r))
def twin_projection(pars_i, pars_ij):
r"""
log p(t_i) := \
Expand All @@ -817,7 +818,7 @@ def twin_projection(pars_i, pars_ij):
return logl, np.array(proj_i)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def sideways_projection(t_i, pars_j, pars_ij):
r"""
log p(t_j) := \
Expand All @@ -840,7 +841,7 @@ def sideways_projection(t_i, pars_j, pars_ij):
return logl, np.array(proj_j)


@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r))
def mutation_gamma_projection(pars_i, pars_j, pars_ij):
r"""
log p(t_m, t_i, t_j) = \
Expand All @@ -867,7 +868,7 @@ def mutation_gamma_projection(pars_i, pars_j, pars_ij):
return 1.0, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def mutation_leafward_projection(t_i, pars_j, pars_ij):
r"""
log p(t_m, t_j) := \
Expand All @@ -891,7 +892,7 @@ def mutation_leafward_projection(t_i, pars_j, pars_ij):
return 1.0, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def mutation_rootward_projection(t_j, pars_i, pars_ij):
r"""
log p(t_m, t_i) := \
Expand All @@ -915,7 +916,7 @@ def mutation_rootward_projection(t_j, pars_i, pars_ij):
return 1.0, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f, _f))
@numba_jit(_tuple((_f, _f1r))(_f, _f))
def mutation_edge_projection(t_i, t_j):
r"""
log p(t_m) := \
Expand All @@ -933,7 +934,7 @@ def mutation_edge_projection(t_i, t_j):
return 1.0, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r))
def mutation_unphased_projection(pars_i, pars_j, pars_ij):
r"""
log p(t_m, t_i, t_j) := \
Expand Down Expand Up @@ -961,7 +962,7 @@ def mutation_unphased_projection(pars_i, pars_j, pars_ij):
return pr_m, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r))
def mutation_twin_projection(pars_i, pars_ij):
r"""
log p(t_m, t_i) := \
Expand All @@ -985,7 +986,7 @@ def mutation_twin_projection(pars_i, pars_ij):
return pr_m, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def mutation_sideways_projection(t_i, pars_j, pars_ij):
r"""
log p(t_m, t_j) := \
Expand All @@ -1010,7 +1011,7 @@ def mutation_sideways_projection(t_i, pars_j, pars_ij):
return pr_m, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f, _f))
@numba_jit(_tuple((_f, _f1r))(_f, _f))
def mutation_block_projection(t_i, t_j):
r"""
log p(t_m) := \
Expand Down
4 changes: 2 additions & 2 deletions tsdate/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import operator
from collections import defaultdict

import numba
import numpy as np
import scipy.stats
import tskit
from tqdm.auto import tqdm

from .accelerate import numba_jit
from .node_time_class import LIN_GRID, LOG_GRID


Expand Down Expand Up @@ -368,7 +368,7 @@ class LogLikelihoods(Likelihoods):
"""

@staticmethod
@numba.jit(nopython=True)
@numba_jit
def logsumexp(X):
alpha = -np.inf
r = 0.0
Expand Down
Loading

0 comments on commit 5f8a8ca

Please sign in to comment.