Skip to content

Commit

Permalink
Merge pull request #127 from szhan/refactor_get_ts_simple
Browse files Browse the repository at this point in the history
Refactor functions to simulate test data
  • Loading branch information
szhan authored Jun 21, 2024
2 parents a530b54 + 0f613ad commit e518bf6
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 375 deletions.
81 changes: 42 additions & 39 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,62 +196,65 @@ def get_examples_pars(

# Prepare simple example datasets.
def get_ts_simple_n10_no_recomb(self, seed=42):
ts = msprime.simulate(
10,
recombination_rate=0,
mutation_rate=0.5,
ts = msprime.sim_mutations(
msprime.sim_ancestry(
samples=10,
ploidy=1,
sequence_length=10,
recombination_rate=0.0,
random_seed=seed,
),
model=msprime.BinaryMutationModel(),
rate=0.5,
random_seed=seed,
)
assert ts.num_sites > 3
return ts

def get_ts_simple_n6(self, seed=42):
ts = msprime.simulate(
6,
recombination_rate=2,
mutation_rate=7,
random_seed=seed,
)
assert ts.num_sites > 5
return ts

def get_ts_simple_n8(self, seed=42):
ts = msprime.simulate(
8,
recombination_rate=2,
mutation_rate=5,
def get_ts_simple(self, num_samples, seed=42):
ts = msprime.sim_mutations(
msprime.sim_ancestry(
samples=num_samples,
ploidy=1,
sequence_length=10,
recombination_rate=2.0,
random_seed=seed,
),
rate=5.0,
model=msprime.BinaryMutationModel(),
random_seed=seed,
)
assert ts.num_sites > 5
return ts

def get_ts_simple_n8_high_recomb(self, seed=42):
ts = msprime.simulate(
8,
recombination_rate=20,
mutation_rate=5,
ts = msprime.sim_mutations(
msprime.sim_ancestry(
samples=8,
ploidy=1,
sequence_length=20,
recombination_rate=20.0,
random_seed=seed,
),
rate=5.0,
model=msprime.BinaryMutationModel(),
random_seed=seed,
)
assert ts.num_trees > 15
assert ts.num_sites > 5
return ts

def get_ts_simple_n16(self, seed=42):
ts = msprime.simulate(
16,
recombination_rate=2,
mutation_rate=5,
random_seed=seed,
)
assert ts.num_sites > 5
return ts

def get_ts_custom_pars(self, ref_panel_size, length, mean_r, mean_mu, seed=42):
ts = msprime.simulate(
ref_panel_size,
length=length,
recombination_rate=mean_r,
mutation_rate=mean_mu,
def get_ts_custom_pars(self, num_samples, seq_length, mean_r, mean_mu, seed=42):
ts = msprime.sim_mutations(
msprime.sim_ancestry(
samples=num_samples,
ploidy=1,
sequence_length=seq_length,
recombination_rate=mean_r,
random_seed=seed,
),
rate=mean_mu,
model=msprime.BinaryMutationModel(),
random_seed=seed,
)
return ts
Expand Down
51 changes: 8 additions & 43 deletions tests/test_api_fb_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,60 +64,25 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("num_samples", [4, 8, 16])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple(num_samples)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
num_samples=30, seq_length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(ts, scale_mutation_rate, include_ancestors)
51 changes: 8 additions & 43 deletions tests/test_api_fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,60 +59,25 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("num_samples", [4, 8, 16, 32])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple(num_samples)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
num_samples=45, seq_length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(ts, scale_mutation_rate, include_ancestors)
51 changes: 8 additions & 43 deletions tests/test_api_vit_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,60 +45,25 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("num_samples", [4, 8, 16])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple(num_samples)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
num_samples=30, seq_length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(ts, scale_mutation_rate, include_ancestors)
51 changes: 8 additions & 43 deletions tests/test_api_vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,60 +40,25 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("num_samples", [4, 8, 16, 32])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple(num_samples)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
num_samples=45, seq_length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(ts, scale_mutation_rate, include_ancestors)
Loading

0 comments on commit e518bf6

Please sign in to comment.