Skip to content

Commit

Permalink
merged testing
Browse files Browse the repository at this point in the history
  • Loading branch information
astheeggeggs committed Jul 10, 2023
1 parent 1e23980 commit 45afaa5
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 6 deletions.
239 changes: 234 additions & 5 deletions python/tests/test_haplotype_matching_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def update_probabilities(self, site, haplotype_state):
match = (
haplotype_state == MISSING or haplotype_state == allelic_state[v]
)
st.value = self.compute_next_probability(site.id, st.value, match)
st.value = self.compute_next_probability(site.id, st.value, match, u)

# Unset the states
allelic_state[tree.root] = -1
Expand All @@ -329,7 +329,7 @@ def update_probabilities(self, site, haplotype_state):

def process_site(self, site, haplotype_state, forwards=True):
if forwards:
# Forwards algorithm
# Forwards algorithm, or forwards pass in Viterbi
self.update_probabilities(site, haplotype_state)
self.compress()
s = self.compute_normalisation_factor()
Expand Down Expand Up @@ -385,7 +385,7 @@ def run_backward(self, h):
def compute_normalisation_factor(self):
    """Placeholder for the per-site normalisation factor computation.

    This version always raises ``NotImplementedError``. Concrete algorithm
    classes visible in this file (e.g. ``ViterbiAlgorithm``) define their
    own ``compute_normalisation_factor``.
    NOTE(review): the enclosing class header is outside this view;
    presumably this is the abstract hook on the shared base class - confirm.
    """
    raise NotImplementedError()

def compute_next_probability(self, site_id, p_last, is_match):
def compute_next_probability(self, site_id, p_last, is_match, node):
raise NotImplementedError()


Expand Down Expand Up @@ -447,6 +447,86 @@ class BackwardMatrix(CompressedMatrix):
"""Class representing a compressed backward matrix."""


class ViterbiMatrix(CompressedMatrix):
    """
    Class representing the compressed Viterbi matrix.

    On top of the state inherited from ``CompressedMatrix``, this keeps a
    log of ``(site, node, required)`` records saying whether recombination
    is required at a given node for a given site. ``traceback`` consumes
    that log to recover a copying path.
    """

    def __init__(self, ts):
        super().__init__(ts)
        # Tuple containing the site, the node in the tree, and whether
        # recombination is required. The initial entry has site -1, which
        # never equals a real site id, so the backward scans in traceback()
        # stop there instead of wrapping around to negative indexes.
        self.recombination_required = [(-1, 0, False)]

    def add_recombination_required(self, site, node, required):
        """Append a ``(site, node, required)`` record to the log."""
        self.recombination_required.append((site, node, required))

    def choose_sample(self, site_id, tree):
        """
        Return a sample node below the best-scoring transition node.

        The node with the largest value in ``self.value_transitions[site_id]``
        is selected (the first one wins on ties), then the tree is descended
        through children that are not themselves transition nodes until a
        sample is reached.

        Raises ``AssertionError`` if no value exceeds -1, or if some node on
        the way down has only transition nodes as children.
        """
        max_value = -1
        u = -1
        for node, value in self.value_transitions[site_id]:
            if value > max_value:
                max_value = value
                u = node
        assert u != -1

        # A set gives constant-time membership tests during the descent, and
        # the comprehension no longer reuses the name of the live variable u.
        transition_nodes = {node for node, _ in self.value_transitions[site_id]}
        while not tree.is_sample(u):
            for v in tree.children(u):
                if v not in transition_nodes:
                    u = v
                    break
            else:
                raise AssertionError("could not find path")
        return u

    def traceback(self):
        """
        Run the traceback and return the chosen node for every site.

        Sites are visited from last to first. Returns an integer numpy array
        of length ``ts.num_sites`` whose entry for a site is the node
        selected at that site.
        """
        m = self.ts.num_sites
        match = np.zeros(m, dtype=int)
        # -1 means "no record for this node at the current site"; otherwise
        # the entry holds the recorded flag as 0 or 1.
        recombination_tree = np.zeros(self.ts.num_nodes, dtype=int) - 1
        tree = tskit.Tree(self.ts)
        tree.last()
        current_node = -1

        # Records are consumed from the end of the log backwards.
        rr_index = len(self.recombination_required) - 1
        for site in reversed(self.ts.sites()):
            # Step the tree back until it covers this site.
            while tree.interval.left > site.position:
                tree.prev()
            assert tree.interval.left <= site.position < tree.interval.right

            # Fill in the recombination tree
            j = rr_index
            while self.recombination_required[j][0] == site.id:
                node, required = self.recombination_required[j][1:]
                recombination_tree[node] = required
                j -= 1

            # -1 means a fresh choice is needed: either this is the first
            # site visited or the previously visited site asked for a switch.
            if current_node == -1:
                current_node = self.choose_sample(site.id, tree)
            match[site.id] = current_node

            # Now traverse up the tree from the current node. The first marked node
            # we meet tells us whether we need to recombine.
            u = current_node
            while u != -1 and recombination_tree[u] == -1:
                u = tree.parent(u)

            assert u != -1
            if recombination_tree[u] == 1:
                # Need to switch at the next site.
                current_node = -1
            # Reset the nodes in the recombination tree.
            j = rr_index
            while self.recombination_required[j][0] == site.id:
                node = self.recombination_required[j][1]
                recombination_tree[node] = -1
                j -= 1
            # Everything recorded for this site has now been consumed.
            rr_index = j

        return match


class ForwardAlgorithm(LsHmmAlgorithm):
"""Runs the Li and Stephens forward algorithm."""

Expand All @@ -472,7 +552,9 @@ def compute_normalisation_factor(self):
s += self.N[j] * st.value
return s

def compute_next_probability(self, site_id, p_last, is_match):
def compute_next_probability(
self, site_id, p_last, is_match, node
): # Note node only used in Viterbi
rho = self.rho[site_id]
mu = self.mu[site_id]
n = self.ts.num_samples
Expand Down Expand Up @@ -541,7 +623,9 @@ def compute_normalisation_factor(self):
s += self.N[j] * st.value
return s

def compute_next_probability(self, site_id, p_next, is_match):
def compute_next_probability(
self, site_id, p_next, is_match, node
): # Note node only used in Viterbi
mu = self.mu[site_id]
n_alleles = self.n_alleles[site_id]

Expand All @@ -564,6 +648,82 @@ def compute_next_probability(self, site_id, p_next, is_match):
return p_next * p_e


class ViterbiAlgorithm(LsHmmAlgorithm):
    """
    Runs the Li and Stephens Viterbi algorithm.

    Results are accumulated in a ``ViterbiMatrix`` stored on ``self.output``.
    """

    def __init__(
        self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10
    ):
        super().__init__(
            ts,
            rho,
            mu,
            alleles,
            n_alleles,
            scale_mutation=scale_mutation,
            precision=precision,
        )
        self.output = ViterbiMatrix(ts)

    def compute_normalisation_factor(self):
        """
        Return the largest value among the entries of ``self.T``.

        Raises ``ValueError`` when that largest value is exactly zero.
        """
        # Sentinel with value -1; it is only replaced by an entry whose
        # value is strictly larger.
        best = ValueTransition(value=-1)
        for candidate in self.T:
            assert candidate.tree_node != tskit.NULL
            if candidate.value > best.value:
                best = candidate
        top_value = best.value
        if top_value == 0:
            raise ValueError(
                "Trying to match non-existent allele with zero mutation rate"
            )
        return top_value

    def compute_next_probability(self, site_id, p_last, is_match, node):
        """
        Return the updated value for ``node`` at ``site_id``.

        The larger of the "no recombination" and "recombination" terms is
        multiplied by the emission probability. Whether the recombination
        term was used is logged on ``self.output`` for this site and node.
        """
        n = self.ts.num_samples
        rho = self.rho[site_id]
        mu = self.mu[site_id]
        n_alleles = self.n_alleles[site_id]

        stay = p_last * (1 - rho + rho / n)
        switch = rho / n
        # Recombination is taken unless staying is strictly more likely.
        must_recombine = not (stay > switch)
        p_t = switch if must_recombine else stay
        self.output.add_recombination_required(site_id, node, must_recombine)

        if self.scale_mutation_based_on_n_alleles:
            # mu is the rate of mutating to one particular allele, so the
            # overall mutation rate is (n_alleles - 1) * mu.
            if is_match:
                p_e = 1 - (n_alleles - 1) * mu
            else:
                # The boolean factor zeroes this at an invariant site.
                p_e = mu - mu * (n_alleles == 1)
        elif n_alleles == 1:
            # Invariant site without scaling: a match is certain and a
            # mismatch is impossible.
            p_e = 1 if is_match else 0
        else:
            # mu is the rate of mutating to anything, so the mismatch
            # probability is shared among the other alleles.
            p_e = 1 - mu if is_match else mu / (n_alleles - 1)

        return p_t * p_e


def ls_forward_tree(
h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False
):
Expand Down Expand Up @@ -636,6 +796,43 @@ def ls_backward_tree(
return ba.run_backward(h)


def ls_viterbi_tree(
    h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False
):
    """
    Viterbi path computation based on a tree sequence.

    :param h: Query haplotype, indexed by site (``h[j]`` is read for each
        site ``j`` when ``alleles`` is None) and passed to ``run_forward``.
    :param ts: Tree sequence to match against.
    :param rho: Recombination values, forwarded to ``ViterbiAlgorithm``.
    :param mu: Mutation values, forwarded to ``ViterbiAlgorithm``.
    :param precision: Forwarded to ``ViterbiAlgorithm``.
    :param alleles: If None, the number of distinct alleles per site is
        counted from the genotype matrix plus ``h``, and the allele list is
        guessed as ``tskit.ALLELES_ACGT`` or ``tskit.ALLELES_01`` from the
        first variant. Otherwise it is handed to ``check_alleles``.
    :param scale_mutation_based_on_n_alleles: Forwarded to
        ``ViterbiAlgorithm`` as ``scale_mutation``.
    :return: The result of ``ViterbiAlgorithm.run_forward(h)``.
    :raises ValueError: If ``alleles`` is None and the first variant shares
        no allele with either candidate allele list.
    """
    if alleles is None:
        # Decode the genotype matrix once rather than once per site.
        genotypes = ts.genotype_matrix()
        n_alleles = np.int8(
            [
                len(np.unique(np.append(genotypes[j, :], h[j])))
                for j in range(ts.num_sites)
            ]
        )
        # Only the first variant is inspected, so read its alleles once.
        first_variant_alleles = next(ts.variants()).alleles
        alleles = tskit.ALLELES_ACGT
        if len(set(alleles).intersection(first_variant_alleles)) == 0:
            alleles = tskit.ALLELES_01
            if len(set(alleles).intersection(first_variant_alleles)) == 0:
                raise ValueError(
                    "Alleles list could not be identified.\n"
                    "Please pass a list of lists of alleles of length m,\n"
                    "or a list of alleles (e.g. tskit.ALLELES_ACGT)"
                )
        # Use the same allele list at every site.
        alleles = [alleles for _ in range(ts.num_sites)]
    else:
        alleles, n_alleles = check_alleles(alleles, ts.num_sites)
    va = ViterbiAlgorithm(
        ts,
        rho,
        mu,
        alleles,
        n_alleles,
        precision=precision,
        scale_mutation=scale_mutation_based_on_n_alleles,
    )
    return va.run_forward(h)


class LSBase:
"""Superclass of Li and Stephens tests."""

Expand Down Expand Up @@ -788,6 +985,10 @@ class FBAlgorithmBase(LSBase):
"""Base for forwards backwards algorithm tests."""


class VitAlgorithmBase(LSBase):
    """Base for Viterbi algorithm tests."""


class TestMirroringHap(FBAlgorithmBase):
"""Tests that mirroring the tree sequence and running forwards and backwards
algorithms gives the same log-likelihood of observing the data."""
Expand Down Expand Up @@ -889,3 +1090,31 @@ def verify(self, ts):
self.assertAllClose(B, B_tree)
self.assertAllClose(F, F_tree)
self.assertAllClose(ll, ll_tree)


class TestTreeViterbiHap(VitAlgorithmBase):
    """Check that the tree and matrix Viterbi implementations agree on the
    log-likelihood, and that the path returned by the tree traceback has
    that same log-likelihood."""

    def verify(self, ts):
        for n, H, s, r, mu in self.example_parameters_haplotypes(ts):
            # Reference result from the matrix implementation; only the
            # log-likelihood is compared, not the path itself.
            _, ll_matrix = ls.viterbi(
                H, s, r, mutation_rate=mu, scale_mutation_based_on_n_alleles=False
            )

            query = s[0, :]
            reduced_ts = ts.simplify(range(1, n + 1), filter_sites=False)
            compressed = ls_viterbi_tree(query, reduced_ts, r, mu)
            ll_tree = np.sum(np.log10(compressed.normalisation_factor))
            self.assertAllClose(ll_matrix, ll_tree)

            # The path recovered by the tree traceback may differ from the
            # matrix path, so it is checked through its likelihood instead.
            tree_path = compressed.traceback()
            ll_of_tree_path = ls.path_ll(
                H,
                s,
                tree_path,
                r,
                mutation_rate=mu,
                scale_mutation_based_on_n_alleles=False,
            )
            self.assertAllClose(ll_matrix, ll_of_tree_path)
2 changes: 1 addition & 1 deletion python/tests/test_haplotype_matching_viterbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Python implementation of the Li and Stephens Viterbi algorithm.
Python implementation of the Li and Stephens Viterbi algorithms.
"""
import itertools
import unittest
Expand Down

0 comments on commit 45afaa5

Please sign in to comment.