Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 15, 2023
1 parent 82e97eb commit ea2c822
Showing 1 changed file with 36 additions and 19 deletions.
55 changes: 36 additions & 19 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,19 +1022,19 @@ def verify(self, ts):
# TODO add params to run the various checks
def check_viterbi(ts, h, recombination=None, mutation=None):
h = np.array(h).astype(np.int8)
n = ts.num_samples
assert len(h) == ts.num_sites
m = ts.num_sites
assert len(h) == m
if recombination is None:
recombination = np.zeros(ts.num_sites) + 1e-9
if mutation is None:
mutation = np.zeros(ts.num_sites)
precision = 22

H = ts.genotype_matrix().T
G = ts.genotype_matrix()

path, ll = ls.viterbi(
H,
h.reshape(1, n),
G,
h.reshape(1, m),
recombination,
mutation_rate=mutation,
scale_mutation_based_on_n_alleles=False,
Expand All @@ -1050,8 +1050,8 @@ def check_viterbi(ts, h, recombination=None, mutation=None):
# the same as ll_tree (and ll).
path_tree = cm.traceback()
ll_check = ls.path_ll(
H,
h.reshape(1, n),
G,
h.reshape(1, m),
path_tree,
recombination,
mutation_rate=mutation,
Expand Down Expand Up @@ -1079,16 +1079,16 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None):
h = np.array(h).astype(np.int8)
n = ts.num_samples
m = ts.num_sites
assert len(h) == ts.num_sites
assert len(h) == m
if recombination is None:
recombination = np.zeros(ts.num_sites) + 1e-9
if mutation is None:
mutation = np.zeros(ts.num_sites)

H = ts.genotype_matrix().T
G = ts.genotype_matrix()
F, c, ll = ls.forwards(
H,
h.reshape(1, n),
G,
h.reshape(1, m),
recombination,
mutation_rate=mutation,
scale_mutation_based_on_n_alleles=False,
Expand Down Expand Up @@ -1120,17 +1120,17 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None):
def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
precision = 22
h = np.array(h).astype(np.int8)
n = ts.num_samples
assert len(h) == ts.num_sites
m = ts.num_sites
assert len(h) == m
if recombination is None:
recombination = np.zeros(ts.num_sites) + 1e-9
if mutation is None:
mutation = np.zeros(ts.num_sites)

H = ts.genotype_matrix().T
G = ts.genotype_matrix()
B = ls.backwards(
H,
h.reshape(1, n),
G,
h.reshape(1, m),
forward_cm.normalisation_factor,
recombination,
mutation_rate=mutation,
Expand All @@ -1148,8 +1148,6 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
nt.assert_array_equal(
backward_cm.normalisation_factor, forward_cm.normalisation_factor
)
B_tree = backward_cm.decode()
nt.assert_array_almost_equal(B, B_tree)

ll_ts = ts._ll_tree_sequence
ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision)
Expand All @@ -1170,11 +1168,17 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
for node in py_site.keys():
assert py_site[node] == lib_site[node]

# Something weird is happening here - why don't these agree?

nt.assert_array_equal(cm_lib.normalisation_factor, forward_cm.normalisation_factor)
B_lib = cm_lib.decode()
B_tree = backward_cm.decode()
# print(B_tree)
# print(B_lib)
# print(B)
nt.assert_array_almost_equal(B_tree, B_lib)
nt.assert_array_almost_equal(B, B_lib)
# print(B_lib)
# print(B)


def add_unique_sample_mutations(ts, start=0):
Expand Down Expand Up @@ -1258,3 +1262,16 @@ def test_switch_each_sample_missing_middle(self):
nt.assert_array_equal([0, 3, 3, 3], path)
cm = check_forward_matrix(ts, h)
check_backward_matrix(ts, h, cm)


class TestSimulationExamples:
    """Run the Viterbi, forward and backward checks on msprime simulations."""

    @pytest.mark.parametrize("n", [5, 10, 50])
    @pytest.mark.parametrize("L", [1, 10, 100])
    def test_continuous_genome(self, n, L):
        """Check an all-zeros query haplotype against a simulated tree sequence.

        ``n`` is passed as the first argument to ``msprime.simulate`` and ``L``
        as its ``length``; the seed is fixed so each parametrisation is
        reproducible.
        """
        sim_options = dict(
            length=L,
            recombination_rate=1,
            mutation_rate=1,
            random_seed=42,
        )
        sim_ts = msprime.simulate(n, **sim_options)
        # One int8 zero per site in the simulated tree sequence.
        zero_hap = np.zeros(sim_ts.num_sites, dtype=np.int8)
        check_viterbi(sim_ts, zero_hap)
        # The forward check returns the matrix object that the backward check
        # takes as its ``forward_cm`` argument.
        forward_cm = check_forward_matrix(sim_ts, zero_hap)
        check_backward_matrix(sim_ts, zero_hap, forward_cm)

0 comments on commit ea2c822

Please sign in to comment.