Skip to content

Commit

Permalink
Speed up diploid LS tests by reducing parameter space checked
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 23, 2023
1 parent 124cb51 commit 09ae207
Showing 1 changed file with 32 additions and 28 deletions.
60 changes: 32 additions & 28 deletions python/tests/test_genotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,6 @@ def process_site(
s = self.output.normalisation_factor[site.id]
for st1 in self.T:
if st1.tree_node != tskit.NULL:

for st2 in st1.value_list:
st2.value = (
((self.rho[site.id] / self.ts.num_samples) ** 2)
Expand Down Expand Up @@ -1198,7 +1197,6 @@ def genotype_emission(self, mu, m):
return e

def example_genotypes(self, ts):

H = ts.genotype_matrix()
s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0])
H = H[:, 2:]
Expand Down Expand Up @@ -1247,9 +1245,8 @@ def example_parameters_genotypes(self, ts, seed=42):
rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)]
mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33]

e = self.genotype_emission(mu, m)

for s, r, mu in itertools.product(genotypes, rs, mus):
s = genotypes[0]
for r, mu in itertools.product(rs, mus):
r[0] = 0
e = self.genotype_emission(mu, m)
yield n, m, G, s, e, r, mu
Expand All @@ -1267,36 +1264,44 @@ def test_simple_n_10_no_recombination(self):
assert ts.num_sites > 3
self.verify(ts)

def test_simple_n_10_no_recombination_high_mut(self):
ts = msprime.simulate(10, recombination_rate=0, mutation_rate=3, random_seed=42)
assert ts.num_sites > 3
self.verify(ts)

def test_simple_n_10_no_recombination_higher_mut(self):
ts = msprime.simulate(20, recombination_rate=0, mutation_rate=3, random_seed=42)
assert ts.num_sites > 3
self.verify(ts)

def test_simple_n_6(self):
ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42)
assert ts.num_sites > 5
self.verify(ts)

def test_simple_n_8(self):
ts = msprime.simulate(8, recombination_rate=2, mutation_rate=5, random_seed=42)
assert ts.num_sites > 5
self.verify(ts)

def test_simple_n_8_high_recombination(self):
ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42)
assert ts.num_trees > 15
assert ts.num_sites > 5
self.verify(ts)

def test_simple_n_16(self):
ts = msprime.simulate(16, recombination_rate=2, mutation_rate=5, random_seed=42)
assert ts.num_sites > 5
self.verify(ts)
# FIXME Reducing the number of test cases here as they take a long time to run,
# and we will want to refactor the test infrastructure when implementing these
# diploid methods in the library.

# def test_simple_n_10_no_recombination_high_mut(self):
# ts = msprime.simulate(
# 10, recombination_rate=0, mutation_rate=3, random_seed=42)
# assert ts.num_sites > 3
# self.verify(ts)

# def test_simple_n_10_no_recombination_higher_mut(self):
# ts = msprime.simulate(
# 20, recombination_rate=0, mutation_rate=3, random_seed=42)
# assert ts.num_sites > 3
# self.verify(ts)

# def test_simple_n_8(self):
# ts = msprime.simulate(
# 8, recombination_rate=2, mutation_rate=5, random_seed=42)
# assert ts.num_sites > 5
# self.verify(ts)

# def test_simple_n_16(self):
# ts = msprime.simulate(
# 16, recombination_rate=2, mutation_rate=5, random_seed=42)
# assert ts.num_sites > 5
# self.verify(ts)

def verify(self, ts):
raise NotImplementedError()
Expand Down Expand Up @@ -1436,7 +1441,6 @@ class TestTreeViterbiDip(VitAlgorithmBase):
"""

def verify(self, ts):

for n, m, _, s, _, r, mu in self.example_parameters_genotypes(ts):
# Note, need to remove the first sample from the ts, and ensure that
# invariant sites aren't removed.
Expand All @@ -1450,14 +1454,14 @@ def verify(self, ts):
)
ts_check = ts.simplify(range(1, n + 1), filter_sites=False)
phased_path, ll = ls.viterbi(
G_check, s, r, mutation_rate=mu, scale_mutation_based_on_n_alleles=False
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
)
path_ll_matrix = ls.path_ll(
G_check,
s,
phased_path,
r,
mutation_rate=mu,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
)

Expand All @@ -1472,7 +1476,7 @@ def verify(self, ts):
s,
np.transpose(path_tree_dict),
r,
mutation_rate=mu,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
)

Expand Down

0 comments on commit 09ae207

Please sign in to comment.