Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the numba cache to be used, for development #441

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tsdate/accelerate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
from typing import Callable

from numba import jit

# By default the numba JIT cache is disabled: a stale on-disk cache can
# silently run outdated compiled code while the source is being edited.
# Developers who want faster repeated imports can opt in by setting
# TSDATE_DISABLE_NUMBA_CACHE=0 in the environment.
_DISABLE_CACHE = os.environ.get("TSDATE_DISABLE_NUMBA_CACHE", "1")

try:
    # Map the env-var string onto the boolean numba expects for `cache=`.
    CACHE_NUMBA = {"0": True, "1": False}[_DISABLE_CACHE]
except KeyError as e:  # pragma: no cover
    raise KeyError(
        "Environment variable 'TSDATE_DISABLE_NUMBA_CACHE' must be '0' or '1'"
    ) from e


# Baseline keyword arguments applied to every numba-jitted function in
# tsdate; individual call sites may override these via `numba_jit(**kwargs)`.
DEFAULT_NUMBA_ARGS = {
    "nopython": True,
    "cache": CACHE_NUMBA,
}


def numba_jit(*args, **kwargs) -> Callable:  # pragma: no cover
    """Wrap :func:`numba.jit` with the project-wide default options.

    Positional arguments (e.g. an explicit numba signature) are forwarded
    unchanged; keyword arguments supplied by the caller take precedence
    over the entries in ``DEFAULT_NUMBA_ARGS``.
    """
    merged = {**DEFAULT_NUMBA_ARGS, **kwargs}
    return jit(*args, **merged)
31 changes: 16 additions & 15 deletions tsdate/approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import numpy as np

from . import hypergeo
from .accelerate import numba_jit

# TODO: these are reasonable defaults but could
# be set via a control dict
Expand Down Expand Up @@ -70,7 +71,7 @@ class KLMinimizationFailedError(Exception):
pass


@numba.njit(_unituple(_f, 3)(_f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f))
def approximate_log_moments(mean, variance):
"""
Approximate log moments via a second-order Taylor series expansion around
Expand All @@ -88,7 +89,7 @@ def approximate_log_moments(mean, variance):
return logx, xlogx, logx2


@numba.njit(_unituple(_f, 2)(_f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f))
def approximate_gamma_kl(x, logx):
"""
Use Newton root finding to get gamma natural parameters matching the sufficient
Expand Down Expand Up @@ -126,7 +127,7 @@ def approximate_gamma_kl(x, logx):
return alpha - 1.0, alpha / x


@numba.njit(_unituple(_f, 2)(_f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f))
def approximate_gamma_mom(mean, variance):
"""
Use the method of moments to approximate a distribution with a gamma of the
Expand Down Expand Up @@ -177,7 +178,7 @@ def approximate_gamma_iqr(q1, q2, x1, x2):
return alpha - 1, beta


@numba.njit(_unituple(_f, 2)(_f1r, _f1r))
@numba_jit(_unituple(_f, 2)(_f1r, _f1r))
def average_gammas(alpha, beta):
"""
Given natural parameters for a set of gammas, average sufficient
Expand All @@ -195,7 +196,7 @@ def average_gammas(alpha, beta):
return approximate_gamma_kl(avg_x, avg_logx)


@numba.njit(_b(_f, _f))
@numba_jit(_b(_f, _f))
def _valid_moments(mn, va):
if not (np.isfinite(mn) and np.isfinite(va)):
return False
Expand All @@ -204,7 +205,7 @@ def _valid_moments(mn, va):
return True


@numba.njit(_b(_f, _f))
@numba_jit(_b(_f, _f))
def _valid_gamma(s, r):
if not (np.isfinite(s) and np.isfinite(r)):
return False
Expand All @@ -213,7 +214,7 @@ def _valid_gamma(s, r):
return True


@numba.njit(_b(_f, _f, _f))
@numba_jit(_b(_f, _f, _f))
def _valid_hyp1f1(a, b, z):
if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(z)):
return False
Expand All @@ -222,7 +223,7 @@ def _valid_hyp1f1(a, b, z):
return True


@numba.njit(_b(_f, _f, _f))
@numba_jit(_b(_f, _f, _f))
def _valid_hyperu(a, b, z):
if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(z)):
return False
Expand All @@ -233,7 +234,7 @@ def _valid_hyperu(a, b, z):
return True


@numba.njit(_b(_f, _f, _f, _f))
@numba_jit(_b(_f, _f, _f, _f))
def _valid_hyp2f1(a, b, c, z):
if not (np.isfinite(a) and np.isfinite(b) and np.isfinite(c)):
return False
Expand Down Expand Up @@ -655,7 +656,7 @@ def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij):
return pr_m, mn_m, va_m


@numba.njit(_unituple(_f, 2)(_f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f))
def mutation_edge_moments(t_i, t_j):
r"""
log p(t_m) := \
Expand All @@ -670,7 +671,7 @@ def mutation_edge_moments(t_i, t_j):
return mn_m, va_m


@numba.njit(_unituple(_f, 3)(_f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f))
def mutation_block_moments(t_i, t_j):
r"""
log p(t_m) := \
Expand Down Expand Up @@ -915,7 +916,7 @@ def mutation_rootward_projection(t_j, pars_i, pars_ij):
return 1.0, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f, _f))
@numba_jit(_tuple((_f, _f1r))(_f, _f))
def mutation_edge_projection(t_i, t_j):
r"""
log p(t_m) := \
Expand Down Expand Up @@ -961,7 +962,7 @@ def mutation_unphased_projection(pars_i, pars_j, pars_ij):
return pr_m, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r))
def mutation_twin_projection(pars_i, pars_ij):
r"""
log p(t_m, t_i) := \
Expand All @@ -985,7 +986,7 @@ def mutation_twin_projection(pars_i, pars_ij):
return pr_m, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def mutation_sideways_projection(t_i, pars_j, pars_ij):
r"""
log p(t_m, t_j) := \
Expand All @@ -1010,7 +1011,7 @@ def mutation_sideways_projection(t_i, pars_j, pars_ij):
return pr_m, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f, _f))
@numba_jit(_tuple((_f, _f1r))(_f, _f))
def mutation_block_projection(t_i, t_j):
r"""
log p(t_m) := \
Expand Down
4 changes: 2 additions & 2 deletions tsdate/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import operator
from collections import defaultdict

import numba
import numpy as np
import scipy.stats
import tskit
from tqdm.auto import tqdm

from .accelerate import numba_jit
from .node_time_class import LIN_GRID, LOG_GRID


Expand Down Expand Up @@ -368,7 +368,7 @@ class LogLikelihoods(Likelihoods):
"""

@staticmethod
@numba.jit(nopython=True)
@numba_jit
def logsumexp(X):
alpha = -np.inf
r = 0.0
Expand Down
12 changes: 7 additions & 5 deletions tsdate/hypergeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import numpy as np
from numba.extending import get_cython_function_address

from .accelerate import numba_jit

_HYP2F1_TOL = 1e-10
_HYP2F1_MAXTERM = int(1e6)

Expand Down Expand Up @@ -92,7 +94,7 @@ def _erf_inv(x):
return _erfinv_f8(x)


@numba.njit("f8(f8)")
@numba_jit("f8(f8)")
def _digamma(x):
"""
Digamma (psi) function, from asymptotic series expansion.
Expand All @@ -116,7 +118,7 @@ def _digamma(x):
)


@numba.njit("f8(f8)")
@numba_jit("f8(f8)")
def _trigamma(x):
"""
Trigamma function, from asymptotic series expansion
Expand Down Expand Up @@ -147,7 +149,7 @@ def _betaln(p, q):
return _gammaln(p) + _gammaln(q) - _gammaln(p + q)


@numba.njit("UniTuple(f8, 2)(f8, f8, f8)")
@numba_jit("UniTuple(f8, 2)(f8, f8, f8)")
def _hyperu_laplace(a, b, x):
"""
Approximate Tricomi's confluent hypergeometric function with real
Expand Down Expand Up @@ -175,7 +177,7 @@ def _hyperu_laplace(a, b, x):
return g - log(r) / 2, (dg - dr) * du - u


@numba.njit("f8(f8, f8, f8)")
@numba_jit("f8(f8, f8, f8)")
def _hyp1f1_laplace(a, b, x):
"""
Approximate Kummer's confluent hypergeometric function with real arguments,
Expand Down Expand Up @@ -278,7 +280,7 @@ def _hyp2f1_laplace(a, b, c, x):
return f - log(r) / 2 + s


@numba.njit("f8(f8, f8)")
@numba_jit("f8(f8, f8)")
def _gammainc_der(p, x):
"""
Derivative of lower incomplete gamma function with regards to `p`.
Expand Down
8 changes: 4 additions & 4 deletions tsdate/phasing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
Tools for phasing singleton mutations
"""

import numba
import numpy as np
import tskit

from .accelerate import numba_jit
from .approx import _b1r, _b2r, _f, _f1r, _f2w, _i1r, _i1w, _i2r, _i2w, _tuple, _void

# --- machinery used by ExpectationPropagation class --- #


@numba.njit(_void(_f2w, _f1r, _i1r, _i2r))
@numba_jit(_void(_f2w, _f1r, _i1r, _i2r))
def reallocate_unphased(edges_likelihood, mutations_phase, mutations_block, blocks_edges):
"""
Add a proportion of each unphased singleton mutation to one of the two
Expand Down Expand Up @@ -64,7 +64,7 @@ def reallocate_unphased(edges_likelihood, mutations_phase, mutations_block, bloc
assert np.isclose(num_unphased, np.sum(edges_likelihood[edges_unphased, 0]))


@numba.njit(
@numba_jit(
_tuple((_f2w, _i2w, _i1w))(
_b1r, _i1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f
)
Expand Down Expand Up @@ -205,7 +205,7 @@ def block_singletons(ts, individuals_unphased):
)


@numba.njit(_i2w(_b2r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f))
@numba_jit(_i2w(_b2r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f))
def _mutation_frequency(
nodes_sample,
mutations_node,
Expand Down
15 changes: 8 additions & 7 deletions tsdate/rescaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import numpy as np
import tskit

from .accelerate import numba_jit
from .approx import (
_b,
_b1r,
Expand All @@ -48,7 +49,7 @@
from .util import mutation_span_array # NOQA: F401


@numba.njit(_i1w(_f1r, _i))
@numba_jit(_i1w(_f1r, _i))
def _fixed_changepoints(counts, epochs):
"""
Find breakpoints such that `counts` is divided roughly equally across `epochs`
Expand All @@ -65,7 +66,7 @@ def _fixed_changepoints(counts, epochs):
return e.astype(np.int32)


@numba.njit(_i1w(_f1r, _f1r, _f, _f, _f))
@numba_jit(_i1w(_f1r, _f1r, _f, _f, _f))
def _poisson_changepoints(counts, offset, penalty, min_counts, min_offset):
"""
Given Poisson counts and offsets for a sequence of observations, find the set
Expand Down Expand Up @@ -113,7 +114,7 @@ def f(i, j): # loss
return breaks


@numba.njit(
@numba_jit(
_tuple((_f2w, _i1w))(_b1r, _i1r, _f1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r, _f, _b)
)
def _count_mutations(
Expand Down Expand Up @@ -238,7 +239,7 @@ def count_mutations(ts, node_is_sample=None, size_biased=False):
)


@numba.njit(_tuple((_f1w, _f1w, _f1w, _i1w))(_f1r, _f2r, _i1r, _i1r))
@numba_jit(_tuple((_f1w, _f1w, _f1w, _i1w))(_f1r, _f2r, _i1r, _i1r))
def mutational_area(
nodes_time,
likelihoods,
Expand Down Expand Up @@ -295,7 +296,7 @@ def mutational_area(
return counts, offset, duration, nodes_index


# @numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i))
# @numba_jit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _f1r, _i))
# def mutational_timescale(
# nodes_time,
# likelihoods,
Expand Down Expand Up @@ -382,7 +383,7 @@ def mutational_area(
# return origin, adjust


@numba.njit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _i))
@numba_jit(_unituple(_f1w, 2)(_f1r, _f2r, _f2r, _i1r, _i1r, _i))
def mutational_timescale(
nodes_time,
likelihoods,
Expand Down Expand Up @@ -501,7 +502,7 @@ def rescale(x):
return new_posteriors


@numba.njit(_f1w(_f1r, _f1r, _f1r))
@numba_jit(_f1w(_f1r, _f1r, _f1r))
def piecewise_scale_point_estimate(
point_estimate,
original_breaks,
Expand Down
Loading