Skip to content

Commit

Permalink
Fix emission probability matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 3, 2024
1 parent d90badc commit 8002bd4
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 32 deletions.
44 changes: 31 additions & 13 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,44 +303,62 @@ def get_emission_probability_haploid(ref_allele, query_allele, site, emission_ma


@jit.numba_njit
def get_emission_matrix_haploid_tstv(mu, kappa=None):
    """
    Return an emission probability matrix that allows for mutational bias
    towards transitions or transversions.

    Transition and transversion probabilities are defined such that
    the probability of a particular type of transition is equal to
    `kappa` * the probability of a particular type of transversion,
    and that the total probability of mutation is equal to `mu`.

    Each allele has one transition partner and two transversion partners,
    so for a site with mutation probability `m` the entries are:
    `1 - m` on the diagonal, `kappa * m / (2 + kappa)` for a transition,
    and `m / (2 + kappa)` for a transversion.

    When `kappa` is set to None, it defaults to 1.

    :param mu: Per-site probability of mutation to any allele
        (indexed by site; one entry per site).
    :param float kappa: Transition-to-transversion rate ratio.
    :return: Array of shape (number of sites, 4, 4) indexed by
        (site, allele, allele).
    :raises ValueError: If any entry of `mu` is outside [0, 1], if `kappa`
        is not positive, or if a row of the matrix does not sum to one.
    """
    if np.any(mu < 0.0) or np.any(mu > 1.0):
        raise ValueError("Probability of mutation must be in [0, 1].")

    if kappa is not None and kappa <= 0:
        raise ValueError("Transition-to-transversion rate ratio must be positive.")

    if kappa is None:
        kappa = 1.0

    num_sites = len(mu)
    num_alleles = 4  # Assume that ACGT are encoded as 0 to 3.

    # Initialise every entry to -1 (not a valid probability); all entries
    # are overwritten below, and the row-sum check would flag any that were not.
    emission_matrix = (
        np.zeros((num_sites, num_alleles, num_alleles), dtype=np.float64) - 1
    )

    # Define transitions: A <-> G and C <-> T.
    transitions = [(0, 2), (2, 0), (1, 3), (3, 1)]

    for i in range(num_sites):
        # These depend only on the site, so compute them once per site.
        match_prob = 1 - mu[i]
        transversion_prob = mu[i] / (2.0 + kappa)
        transition_prob = transversion_prob * kappa

        for j in range(num_alleles):
            for k in range(num_alleles):
                if j == k:
                    emission_matrix[i, j, k] = match_prob
                elif (j, k) in transitions:
                    emission_matrix[i, j, k] = transition_prob
                else:
                    emission_matrix[i, j, k] = transversion_prob

            # Each row is a probability distribution over the four alleles.
            row_sum = np.sum(emission_matrix[i, j, :])
            if not np.isclose(row_sum, 1.0):
                err_msg = f"Row values must sum to one. {row_sum}"
                raise ValueError(err_msg)

    return emission_matrix


@jit.numba_njit
def get_emission_probability_haploid_hkylike(
def get_emission_probability_haploid_tstv(
ref_allele, query_allele, site, emission_matrix
):
"""
Expand Down
38 changes: 19 additions & 19 deletions tests/test_nontree_vit_haploid_tstv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def verify(self, ts, include_ancestors):
r_s = [
np.zeros(m) + 0.01,
np.random.rand(m),
1e-5 * (np.random.rand(m - 1) + 0.5) / 2,
np.zeros(m - 1) + 0.2,
np.zeros(m - 1) + 1e-6,
1e-5 * (np.random.rand(m) + 0.5) / 2,
np.zeros(m) + 0.2,
np.zeros(m) + 1e-6,
]
mu_s = [
np.zeros(m) + 0.01,
Expand All @@ -29,10 +29,10 @@ def verify(self, ts, include_ancestors):
np.zeros(m) + 0.2,
np.zeros(m) + 1e-6,
]
kappa_s = [1 / 2, 1 / 4, 1, 3 / 2, 2]
kappa_s = [0.25, 0.5, 1.0, 1.5, 2.0]

for s, r, mu, kappa in itertools.product(queries, r_s, mu_s, kappa_s):
e = core.get_emission_matrix_haploid_hkylike(mu, kappa)
e = core.get_emission_matrix_haploid_tstv(mu, kappa)

V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(
n=n,
Expand All @@ -41,7 +41,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs[m - 1, :], P=P_vs)
ll_check = vh.path_ll_hap(
Expand All @@ -52,7 +52,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_vs, ll_check)

Expand All @@ -63,7 +63,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp[m - 1, :], P=P_tmp)
ll_check = vh.path_ll_hap(
Expand All @@ -74,7 +74,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)
Expand All @@ -86,7 +86,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
Expand All @@ -97,7 +97,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)
Expand All @@ -109,7 +109,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
ll_check = vh.path_ll_hap(
Expand All @@ -120,7 +120,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)
Expand All @@ -132,7 +132,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
Expand All @@ -143,7 +143,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)
Expand All @@ -155,7 +155,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
ll_check = vh.path_ll_hap(
Expand All @@ -166,7 +166,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)
Expand All @@ -183,7 +183,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap_no_pointer(
m=m,
Expand All @@ -198,7 +198,7 @@ def verify(self, ts, include_ancestors):
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_hkylike,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)
Expand Down

0 comments on commit 8002bd4

Please sign in to comment.