Skip to content

Commit

Permalink
added missingness to diploid viterbi
Browse files Browse the repository at this point in the history
  • Loading branch information
astheeggeggs committed Sep 28, 2022
1 parent 8d43085 commit c548679
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 44 deletions.
47 changes: 27 additions & 20 deletions python/tests/test_genotype_matching_fb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

MISSING = -1


def mirror_coordinates(ts):
"""
Returns a copy of the specified tree sequence in which all
Expand Down Expand Up @@ -249,6 +250,8 @@ def stupid_compress_dict(self):
u = tree.parent(u)
self.N[self.T_index[u]] += 1

print(self.T)

def update_tree(self):
"""
Update the internal data structures to move on to the next tree.
Expand Down Expand Up @@ -343,7 +346,7 @@ def update_tree(self):

# Also need to mark the corresponding InternalValueTransition as
# unused for the remaining states

for st2 in T:
if not (st2.value_list == tskit.NULL):
st2.value_list[T_index[vt.tree_node]].value = -1
Expand Down Expand Up @@ -452,7 +455,7 @@ def update_probabilities(self, site, genotype_state):
match,
template_is_het,
query_is_het,
query_is_missing
query_is_missing,
)

# This will ensure that allelic_state[:n] is filled
Expand Down Expand Up @@ -570,7 +573,14 @@ def compute_normalisation_factor_dict(self):
raise NotImplementedError()

def compute_next_probability_dict(
self, site_id, p_last, inner_summation, is_match, template_is_het, query_is_het, query_is_missing
self,
site_id,
p_last,
inner_summation,
is_match,
template_is_het,
query_is_het,
query_is_missing,
):
raise NotImplementedError()

Expand Down Expand Up @@ -679,7 +689,7 @@ def compute_next_probability_dict(
is_match,
template_is_het,
query_is_het,
query_is_missing
query_is_missing,
):
rho = self.rho[site_id]
mu = self.mu[site_id]
Expand All @@ -691,12 +701,11 @@ def compute_next_probability_dict(
+ (1 - rho) ** 2 * p_last
)

template_is_hom = np.logical_not(template_is_het)

if query_is_missing:
p_e = 1
else:
query_is_hom = np.logical_not(query_is_het)
template_is_hom = np.logical_not(template_is_het)

EQUAL_BOTH_HOM = np.logical_and(
np.logical_and(is_match, template_is_hom), query_is_hom
Expand All @@ -710,10 +719,10 @@ def compute_next_probability_dict(

p_e = (
EQUAL_BOTH_HOM * (1 - mu) ** 2
+ UNEQUAL_BOTH_HOM * (mu ** 2)
+ UNEQUAL_BOTH_HOM * (mu**2)
+ REF_HOM_OBS_HET * (2 * mu * (1 - mu))
+ REF_HET_OBS_HOM * (mu * (1 - mu))
+ BOTH_HET * ((1 - mu) ** 2 + mu ** 2)
+ BOTH_HET * ((1 - mu) ** 2 + mu**2)
)

return p_t * p_e
Expand Down Expand Up @@ -751,7 +760,7 @@ def compute_next_probability_dict(
is_match,
template_is_het,
query_is_het,
query_is_missing
query_is_missing,
):
mu = self.mu[site_id]

Expand All @@ -774,12 +783,12 @@ def compute_next_probability_dict(

p_e = (
EQUAL_BOTH_HOM * (1 - mu) ** 2
+ UNEQUAL_BOTH_HOM * (mu ** 2)
+ UNEQUAL_BOTH_HOM * (mu**2)
+ REF_HOM_OBS_HET * (2 * mu * (1 - mu))
+ REF_HET_OBS_HOM * (mu * (1 - mu))
+ BOTH_HET * ((1 - mu) ** 2 + mu ** 2)
+ BOTH_HET * ((1 - mu) ** 2 + mu**2)
)

return p_next * p_e


Expand All @@ -804,14 +813,13 @@ def genotype_emission(self, mu, m):
# Define the emission probability matrix
e = np.zeros((m, 8))
e[:, EQUAL_BOTH_HOM] = (1 - mu) ** 2
e[:, UNEQUAL_BOTH_HOM] = mu ** 2
e[:, BOTH_HET] = (1 - mu) ** 2 + mu ** 2
e[:, UNEQUAL_BOTH_HOM] = mu**2
e[:, BOTH_HET] = (1 - mu) ** 2 + mu**2
e[:, REF_HOM_OBS_HET] = 2 * mu * (1 - mu)
e[:, REF_HET_OBS_HOM] = mu * (1 - mu)

return e


def example_genotypes(self, ts):

H = ts.genotype_matrix()
Expand Down Expand Up @@ -840,10 +848,8 @@ def example_genotypes(self, ts):
for i in range(m):
G[i, :, :] = np.add.outer(H[i, :], H[i, :])


return H, G, genotypes


def example_parameters_genotypes(self, ts, seed=42):
np.random.seed(seed)
H, G, genotypes = self.example_genotypes(ts)
Expand Down Expand Up @@ -871,7 +877,6 @@ def example_parameters_genotypes(self, ts, seed=42):
e = self.genotype_emission(mu, m)
yield n, m, G, s, e, r, mu


def assertAllClose(self, A, B):
"""Assert that all entries of two matrices are 'close'"""
assert np.allclose(A, B, rtol=1e-5, atol=1e-8)
Expand Down Expand Up @@ -924,6 +929,7 @@ def verify(self, ts):
class FBAlgorithmBase(LSBase):
"""Base for forwards backwards algorithm tests."""


class TestMirroringDipdict(FBAlgorithmBase):
"""Tests that mirroring the tree sequence and running forwards and backwards
algorithms give the same log-likelihood of observing the data."""
Expand Down Expand Up @@ -952,7 +958,7 @@ def verify(self, ts):
ll_mirror_tree_dict = np.sum(np.log10(cm_mirror.normalisation_factor))

self.assertAllClose(ll_tree, ll_mirror_tree_dict)

# Ensure that the decoded matrices are the same
F_mirror_matrix, c, ll = ls.forwards(
np.flip(G_check, axis=0),
Expand Down Expand Up @@ -991,6 +997,7 @@ def verify(self, ts):
ll_tree = np.sum(np.log10(cm_d.normalisation_factor))
self.assertAllClose(ll, ll_tree)


class TestForwardBackwardTree(FBAlgorithmBase):
"""Tests that the tree algorithm computes the same forward matrix as the simple
method."""
Expand Down Expand Up @@ -1022,7 +1029,7 @@ def verify(self, ts):

# Note, need to remove the first sample from the ts, and ensure that
# invariant sites aren't removed.

ts_check = ts.simplify(range(1, n + 1), filter_sites=False)
c_f = ls_forward_tree(s[0, :], ts_check, r, mu)
ll_tree = np.sum(np.log10(c_f.normalisation_factor))
Expand Down
74 changes: 50 additions & 24 deletions python/tests/test_genotype_matching_viterbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
REF_HOM_OBS_HET = 1
REF_HET_OBS_HOM = 2

MISSING = -1


class ValueTransition:
"""Simple struct holding value transition values."""
Expand Down Expand Up @@ -390,6 +392,7 @@ def update_probabilities(self, site, genotype_state):
]

query_is_het = genotype_state == 1
query_is_missing = genotype_state == MISSING

for st1 in T:
u1 = st1.tree_node
Expand Down Expand Up @@ -423,6 +426,7 @@ def update_probabilities(self, site, genotype_state):
match,
template_is_het,
query_is_het,
query_is_missing,
u1,
u2,
)
Expand Down Expand Up @@ -486,6 +490,7 @@ def compute_next_probability_dict(
is_match,
template_is_het,
query_is_het,
query_is_missing,
node_1,
node_2,
):
Expand Down Expand Up @@ -830,6 +835,7 @@ def compute_next_probability_dict(
is_match,
template_is_het,
query_is_het,
query_is_missing,
node_1,
node_2,
):
Expand All @@ -841,26 +847,28 @@ def compute_next_probability_dict(
double_recombination_required = False
single_recombination_required = False

template_is_hom = np.logical_not(template_is_het)
query_is_hom = np.logical_not(query_is_het)

EQUAL_BOTH_HOM = np.logical_and(
np.logical_and(is_match, template_is_hom), query_is_hom
)
UNEQUAL_BOTH_HOM = np.logical_and(
np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom
)
BOTH_HET = np.logical_and(template_is_het, query_is_het)
REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het)
REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom)

p_e = (
EQUAL_BOTH_HOM * (1 - mu) ** 2
+ UNEQUAL_BOTH_HOM * (mu**2)
+ REF_HOM_OBS_HET * (2 * mu * (1 - mu))
+ REF_HET_OBS_HOM * (mu * (1 - mu))
+ BOTH_HET * ((1 - mu) ** 2 + mu**2)
)
if query_is_missing:
p_e = 1
else:
template_is_hom = np.logical_not(template_is_het)
query_is_hom = np.logical_not(query_is_het)
EQUAL_BOTH_HOM = np.logical_and(
np.logical_and(is_match, template_is_hom), query_is_hom
)
UNEQUAL_BOTH_HOM = np.logical_and(
np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom
)
BOTH_HET = np.logical_and(template_is_het, query_is_het)
REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het)
REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom)

p_e = (
EQUAL_BOTH_HOM * (1 - mu) ** 2
+ UNEQUAL_BOTH_HOM * (mu**2)
+ REF_HOM_OBS_HET * (2 * mu * (1 - mu))
+ REF_HET_OBS_HOM * (mu * (1 - mu))
+ BOTH_HET * ((1 - mu) ** 2 + mu**2)
)

no_switch = (1 - r) ** 2 + 2 * (r_n * (1 - r)) + r_n**2
single_switch = r_n * (1 - r) + r_n**2
Expand Down Expand Up @@ -919,18 +927,33 @@ def example_genotypes(self, ts):
s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0])
H = H[:, 2:]

genotypes = [
s,
H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]),
]

s_tmp = s.copy()
s_tmp[0, -1] = MISSING
genotypes.append(s_tmp)
s_tmp = s.copy()
s_tmp[0, ts.num_sites // 2] = MISSING
genotypes.append(s_tmp)
s_tmp = s.copy()
s_tmp[0, :] = MISSING
genotypes.append(s_tmp)

m = ts.get_num_sites()
n = H.shape[1]

G = np.zeros((m, n, n))
for i in range(m):
G[i, :, :] = np.add.outer(H[i, :], H[i, :])

return H, G, s
return H, G, genotypes

def example_parameters_genotypes(self, ts, seed=42):
np.random.seed(seed)
H, G, s = self.example_genotypes(ts)
H, G, genotypes = self.example_genotypes(ts)
n = H.shape[1]
m = ts.get_num_sites()

Expand All @@ -941,13 +964,16 @@ def example_parameters_genotypes(self, ts, seed=42):

e = self.genotype_emission(mu, m)

yield n, m, G, s, e, r, mu
for s in genotypes:
yield n, m, G, s, e, r, mu

# Mixture of random and extremes
rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33]

for r, mu in itertools.product(rs, mus):
e = self.genotype_emission(mu, m)

for s, r, mu in itertools.product(genotypes, rs, mus):
r[0] = 0
e = self.genotype_emission(mu, m)
yield n, m, G, s, e, r, mu
Expand Down

0 comments on commit c548679

Please sign in to comment.