diff --git a/python/tests/test_genotype_matching_fb.py b/python/tests/test_genotype_matching_fb.py index 8f61f12bfc..bc39e75633 100644 --- a/python/tests/test_genotype_matching_fb.py +++ b/python/tests/test_genotype_matching_fb.py @@ -16,6 +16,7 @@ MISSING = -1 + def mirror_coordinates(ts): """ Returns a copy of the specified tree sequence in which all @@ -249,6 +250,8 @@ def stupid_compress_dict(self): u = tree.parent(u) self.N[self.T_index[u]] += 1 + print(self.T) + def update_tree(self): """ Update the internal data structures to move on to the next tree. @@ -343,7 +346,7 @@ def update_tree(self): # Also need to mark the corresponding InternalValueTransition as # unused for the remaining states - + for st2 in T: if not (st2.value_list == tskit.NULL): st2.value_list[T_index[vt.tree_node]].value = -1 @@ -452,7 +455,7 @@ def update_probabilities(self, site, genotype_state): match, template_is_het, query_is_het, - query_is_missing + query_is_missing, ) # This will ensure that allelic_state[:n] is filled @@ -570,7 +573,14 @@ def compute_normalisation_factor_dict(self): raise NotImplementedError() def compute_next_probability_dict( - self, site_id, p_last, inner_summation, is_match, template_is_het, query_is_het, query_is_missing + self, + site_id, + p_last, + inner_summation, + is_match, + template_is_het, + query_is_het, + query_is_missing, ): raise NotImplementedError() @@ -679,7 +689,7 @@ def compute_next_probability_dict( is_match, template_is_het, query_is_het, - query_is_missing + query_is_missing, ): rho = self.rho[site_id] mu = self.mu[site_id] @@ -691,12 +701,11 @@ def compute_next_probability_dict( + (1 - rho) ** 2 * p_last ) - template_is_hom = np.logical_not(template_is_het) - if query_is_missing: p_e = 1 else: query_is_hom = np.logical_not(query_is_het) + template_is_hom = np.logical_not(template_is_het) EQUAL_BOTH_HOM = np.logical_and( np.logical_and(is_match, template_is_hom), query_is_hom @@ -710,10 +719,10 @@ def compute_next_probability_dict( p_e = ( EQUAL_BOTH_HOM * (1 - mu) ** 2 - + UNEQUAL_BOTH_HOM * (mu ** 2) + + UNEQUAL_BOTH_HOM * (mu**2) + REF_HOM_OBS_HET * (2 * mu * (1 - mu)) + REF_HET_OBS_HOM * (mu * (1 - mu)) - + BOTH_HET * ((1 - mu) ** 2 + mu ** 2) + + BOTH_HET * ((1 - mu) ** 2 + mu**2) ) return p_t * p_e @@ -751,7 +760,7 @@ def compute_next_probability_dict( is_match, template_is_het, query_is_het, - query_is_missing + query_is_missing, ): mu = self.mu[site_id] @@ -774,12 +783,12 @@ def compute_next_probability_dict( p_e = ( EQUAL_BOTH_HOM * (1 - mu) ** 2 - + UNEQUAL_BOTH_HOM * (mu ** 2) + + UNEQUAL_BOTH_HOM * (mu**2) + REF_HOM_OBS_HET * (2 * mu * (1 - mu)) + REF_HET_OBS_HOM * (mu * (1 - mu)) - + BOTH_HET * ((1 - mu) ** 2 + mu ** 2) + + BOTH_HET * ((1 - mu) ** 2 + mu**2) ) - + return p_next * p_e @@ -804,14 +813,13 @@ def genotype_emission(self, mu, m): # Define the emission probability matrix e = np.zeros((m, 8)) e[:, EQUAL_BOTH_HOM] = (1 - mu) ** 2 - e[:, UNEQUAL_BOTH_HOM] = mu ** 2 - e[:, BOTH_HET] = (1 - mu) ** 2 + mu ** 2 + e[:, UNEQUAL_BOTH_HOM] = mu**2 + e[:, BOTH_HET] = (1 - mu) ** 2 + mu**2 e[:, REF_HOM_OBS_HET] = 2 * mu * (1 - mu) e[:, REF_HET_OBS_HOM] = mu * (1 - mu) return e - def example_genotypes(self, ts): H = ts.genotype_matrix() @@ -840,10 +848,8 @@ def example_genotypes(self, ts): for i in range(m): G[i, :, :] = np.add.outer(H[i, :], H[i, :]) - return H, G, genotypes - def example_parameters_genotypes(self, ts, seed=42): np.random.seed(seed) H, G, genotypes = self.example_genotypes(ts) @@ -871,7 +877,6 @@ def example_parameters_genotypes(self, ts, seed=42): e = self.genotype_emission(mu, m) yield n, m, G, s, e, r, mu - def assertAllClose(self, A, B): """Assert that all entries of two matrices are 'close'""" assert np.allclose(A, B, rtol=1e-5, atol=1e-8) @@ -924,6 +929,7 @@ def verify(self, ts): class FBAlgorithmBase(LSBase): """Base for forwards backwards algorithm tests.""" + class TestMirroringDipdict(FBAlgorithmBase): """Tests that mirroring the tree sequence and running forwards and backwards algorithms give the same log-likelihood of observing the data.""" @@ -952,7 +958,7 @@ def verify(self, ts): ll_mirror_tree_dict = np.sum(np.log10(cm_mirror.normalisation_factor)) self.assertAllClose(ll_tree, ll_mirror_tree_dict) - + # Ensure that the decoded matrices are the same F_mirror_matrix, c, ll = ls.forwards( np.flip(G_check, axis=0), @@ -991,6 +997,7 @@ def verify(self, ts): ll_tree = np.sum(np.log10(cm_d.normalisation_factor)) self.assertAllClose(ll, ll_tree) + class TestForwardBackwardTree(FBAlgorithmBase): """Tests that the tree algorithm computes the same forward matrix as the simple method.""" @@ -1022,7 +1029,7 @@ def verify(self, ts): # Note, need to remove the first sample from the ts, and ensure that # invariant sites aren't removed. - + ts_check = ts.simplify(range(1, n + 1), filter_sites=False) c_f = ls_forward_tree(s[0, :], ts_check, r, mu) ll_tree = np.sum(np.log10(c_f.normalisation_factor)) diff --git a/python/tests/test_genotype_matching_viterbi.py b/python/tests/test_genotype_matching_viterbi.py index 89377bdb33..03fa0953a8 100644 --- a/python/tests/test_genotype_matching_viterbi.py +++ b/python/tests/test_genotype_matching_viterbi.py @@ -13,6 +13,8 @@ REF_HOM_OBS_HET = 1 REF_HET_OBS_HOM = 2 +MISSING = -1 + class ValueTransition: """Simple struct holding value transition values.""" @@ -390,6 +392,7 @@ def update_probabilities(self, site, genotype_state): ] query_is_het = genotype_state == 1 + query_is_missing = genotype_state == MISSING for st1 in T: u1 = st1.tree_node @@ -423,6 +426,7 @@ def update_probabilities(self, site, genotype_state): match, template_is_het, query_is_het, + query_is_missing, u1, u2, ) @@ -486,6 +490,7 @@ def compute_next_probability_dict( is_match, template_is_het, query_is_het, + query_is_missing, node_1, node_2, ): @@ -830,6 +835,7 @@ def compute_next_probability_dict( is_match, template_is_het, query_is_het, + query_is_missing, node_1, node_2, ): @@ -841,26 +847,28 @@ def compute_next_probability_dict( double_recombination_required = False single_recombination_required = False - template_is_hom = np.logical_not(template_is_het) - query_is_hom = np.logical_not(query_is_het) - - EQUAL_BOTH_HOM = np.logical_and( - np.logical_and(is_match, template_is_hom), query_is_hom - ) - UNEQUAL_BOTH_HOM = np.logical_and( - np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom - ) - BOTH_HET = np.logical_and(template_is_het, query_is_het) - REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het) - REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom) - - p_e = ( - EQUAL_BOTH_HOM * (1 - mu) ** 2 - + UNEQUAL_BOTH_HOM * (mu**2) - + REF_HOM_OBS_HET * (2 * mu * (1 - mu)) - + REF_HET_OBS_HOM * (mu * (1 - mu)) - + BOTH_HET * ((1 - mu) ** 2 + mu**2) - ) + if query_is_missing: + p_e = 1 + else: + template_is_hom = np.logical_not(template_is_het) + query_is_hom = np.logical_not(query_is_het) + EQUAL_BOTH_HOM = np.logical_and( + np.logical_and(is_match, template_is_hom), query_is_hom + ) + UNEQUAL_BOTH_HOM = np.logical_and( + np.logical_and(np.logical_not(is_match), template_is_hom), query_is_hom + ) + BOTH_HET = np.logical_and(template_is_het, query_is_het) + REF_HOM_OBS_HET = np.logical_and(template_is_hom, query_is_het) + REF_HET_OBS_HOM = np.logical_and(template_is_het, query_is_hom) + + p_e = ( + EQUAL_BOTH_HOM * (1 - mu) ** 2 + + UNEQUAL_BOTH_HOM * (mu**2) + + REF_HOM_OBS_HET * (2 * mu * (1 - mu)) + + REF_HET_OBS_HOM * (mu * (1 - mu)) + + BOTH_HET * ((1 - mu) ** 2 + mu**2) + ) no_switch = (1 - r) ** 2 + 2 * (r_n * (1 - r)) + r_n**2 single_switch = r_n * (1 - r) + r_n**2 @@ -919,6 +927,21 @@ def example_genotypes(self, ts): s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0]) H = H[:, 2:] + genotypes = [ + s, + H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]), + ] + + s_tmp = s.copy() + s_tmp[0, -1] = MISSING + genotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, ts.num_sites // 2] = MISSING + genotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, :] = MISSING + genotypes.append(s_tmp) + m = ts.get_num_sites() n = H.shape[1] @@ -926,11 +949,11 @@ def example_genotypes(self, ts): for i in range(m): G[i, :, :] = np.add.outer(H[i, :], H[i, :]) - return H, G, s + return H, G, genotypes def example_parameters_genotypes(self, ts, seed=42): np.random.seed(seed) - H, G, s = self.example_genotypes(ts) + H, G, genotypes = self.example_genotypes(ts) n = H.shape[1] m = ts.get_num_sites() @@ -941,13 +964,16 @@ def example_parameters_genotypes(self, ts, seed=42): e = self.genotype_emission(mu, m) - yield n, m, G, s, e, r, mu + for s in genotypes: + yield n, m, G, s, e, r, mu # Mixture of random and extremes rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - for r, mu in itertools.product(rs, mus): + e = self.genotype_emission(mu, m) + + for s, r, mu in itertools.product(genotypes, rs, mus): r[0] = 0 e = self.genotype_emission(mu, m) yield n, m, G, s, e, r, mu