Skip to content

Commit

Permalink
Merge pull request #41 from Billyzhang1229/feature/maximum_likelihood
Browse files Browse the repository at this point in the history
Likelihood Calculation Felsenstein
  • Loading branch information
jeromekelleher authored May 10, 2024
2 parents c2c2b8d + 4896bec commit 080d0f3
Show file tree
Hide file tree
Showing 12 changed files with 576 additions and 23 deletions.
79 changes: 79 additions & 0 deletions examples/likelihood.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"import msprime\n",
"import numpy as np\n",
"\n",
"# import local phylokit modules\n",
"phylokit_path = os.path.abspath(os.path.join(os.pardir))\n",
"if phylokit_path not in sys.path:\n",
" sys.path.append(phylokit_path)\n",
"\n",
"import phylokit as pk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def simulate_ts(num_samples, sequence_length, mutation_rate, seed=1234):\n",
" tsa = msprime.sim_ancestry(\n",
" num_samples, sequence_length=sequence_length, ploidy=1, random_seed=seed\n",
" )\n",
" return msprime.sim_mutations(tsa, rate=mutation_rate, random_seed=seed)\n",
"\n",
"def create_mutation_tree(num_samples, sequence_length, mutation_rate, seed=1234):\n",
" ts_in = simulate_ts(num_samples, sequence_length, mutation_rate, seed=seed)\n",
" pk_mts = pk.parsimony.hartigan.ts_to_dataset(ts_in)\n",
" ds_in = pk.from_tskit(ts_in.first())\n",
" ds = ds_in.merge(pk_mts)\n",
" return ds\n",
"\n",
"pk_mts = create_mutation_tree(10000, 1000, 0.001, seed=1234)\n",
"pk_mts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mutation_rate = 0.001\n",
"\n",
"pk.likelihood_felsenstein(pk_mts, rate=mutation_rate)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "phylokit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 2 additions & 0 deletions phylokit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .distance import kc_distance
from .distance import mrca
from .distance import rf_distance
from .maximum_likelihood.felsenstein import likelihood_felsenstein
from .parsimony.hartigan import append_parsimony_score
from .parsimony.hartigan import get_hartigan_parsimony_score
from .parsimony.hartigan import numba_hartigan_parsimony_vectorised
Expand Down Expand Up @@ -53,4 +54,5 @@
"numba_hartigan_parsimony_vectorised",
"get_hartigan_parsimony_score",
"append_parsimony_score",
"likelihood_felsenstein",
]
14 changes: 8 additions & 6 deletions phylokit/balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import util


@jit.numba_njit
@jit.numba_njit()
def _sackin_index(virtual_root, left_child, right_sib):
stack = []
root = left_child[virtual_root]
Expand Down Expand Up @@ -43,7 +43,7 @@ def sackin_index(ds):
return _sackin_index(-1, ds.node_left_child.data, ds.node_right_sib.data)


@jit.numba_njit
@jit.numba_njit()
def _colless_index(postorder, left_child, right_sib):
num_leaves = np.zeros_like(left_child)
total = 0.0
Expand Down Expand Up @@ -83,11 +83,13 @@ def colless_index(ds):
if util.get_num_roots(ds) != 1:
raise ValueError("Colless index not defined for multiroot trees")
return _colless_index(
ds.traversal_postorder.data, ds.node_left_child.data, ds.node_right_sib.data
ds.traversal_postorder.data,
ds.node_left_child.data,
ds.node_right_sib.data,
)


@jit.numba_njit
@jit.numba_njit()
def _b1_index(postorder, left_child, right_sib, parent):
max_path_length = np.zeros_like(postorder)
total = 0.0
Expand Down Expand Up @@ -121,7 +123,7 @@ def b1_index(ds):
)


@jit.numba_njit
@jit.numba_njit()
def general_log(x, base):
"""
Compute the logarithm of x in base `base`.
Expand All @@ -134,7 +136,7 @@ def general_log(x, base):
return math.log(x) / math.log(base)


@jit.numba_njit
@jit.numba_njit()
def _b2_index(virtual_root, left_child, right_sib, base):
root = left_child[virtual_root]
stack = [(root, 1)]
Expand Down
4 changes: 2 additions & 2 deletions phylokit/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from . import util


@jit.numba_njit
@jit.numba_njit()
def _mrca(parent, time, u, v):
tu = time[u]
tv = time[v]
Expand Down Expand Up @@ -52,7 +52,7 @@ def mrca(ds, u, v):
return _mrca(ds.node_parent.data, ds.node_time.data, u, v)


@jit.numba_njit
@jit.numba_njit()
def _kc_distance(samples, ds1, ds2):
# ds1 and ds2 are tuples of the form (parent_array, time_array, branch_length, root)
n = samples.shape[0]
Expand Down
2 changes: 1 addition & 1 deletion phylokit/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import jit


@jit.numba_njit
@jit.numba_njit()
def _linkage_matrix_to_dataset(Z):
n = Z.shape[0] + 1
N = 2 * n
Expand Down
19 changes: 14 additions & 5 deletions phylokit/jit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import logging
import os

Expand Down Expand Up @@ -30,8 +31,16 @@
}


def numba_njit(func, **kwargs):
    """
    Compile ``func`` with ``numba.jit`` when ``ENABLE_NUMBA`` is truthy.

    The options passed to ``numba.jit`` are ``DEFAULT_NUMBA_ARGS`` with any
    caller-supplied ``kwargs`` layered on top (caller values win on key
    clashes). When ``ENABLE_NUMBA`` is falsy, ``func`` is returned untouched.
    """
    if ENABLE_NUMBA:  # pragma: no cover
        jit_options = dict(DEFAULT_NUMBA_ARGS)
        jit_options.update(kwargs)
        return numba.jit(func, **jit_options)
    return func
def numba_njit(func=None, **numba_kwargs):
    """
    Decorator that compiles a function with ``numba.jit`` when enabled.

    Works both with and without parentheses::

        @numba_njit
        def f(...): ...

        @numba_njit()
        def g(...): ...

        @numba_njit(cache=True)
        def h(...): ...

    :param func: The function to decorate when used as a bare decorator.
        Left as ``None`` when the decorator is called with parentheses.
    :param numba_kwargs: Extra options for ``numba.jit``. They are merged
        over ``DEFAULT_NUMBA_ARGS``, so caller values win on key clashes.
    :return: The decorated function when ``func`` is given, otherwise a
        decorator to apply to a function. When ``ENABLE_NUMBA`` is falsy
        the function is returned unchanged.
    """

    def _numba_njit(target):
        if ENABLE_NUMBA:  # pragma: no cover
            combined_kwargs = {**DEFAULT_NUMBA_ARGS, **numba_kwargs}
            return numba.jit(**combined_kwargs)(target)
        else:
            # JIT disabled: hand back the plain Python function as-is.
            return target

    if func is not None:
        # Bare ``@numba_njit`` usage: the function arrives positionally.
        return _numba_njit(func)
    return _numba_njit
Empty file.
Loading

0 comments on commit 080d0f3

Please sign in to comment.