Skip to content

Commit

Permalink
num_causal
Browse files Browse the repository at this point in the history
Detect num_causal error in sim_trait function
  • Loading branch information
daikitag authored and mergify[bot] committed Sep 27, 2023
1 parent 8bbf1be commit 5371a97
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/test_sim_trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def test_bad_input(self, sample_ts, sample_trait_model):
ts=sample_ts, num_causal=num_causal, model=sample_trait_model
)

@pytest.mark.parametrize("num_causal", [0, -1])
def test_num_causal_input(self, sample_ts, sample_trait_model, num_causal):
with pytest.raises(TypeError, match="num_causal must be an integer"):
tstrait.sim_trait(ts=sample_ts, num_causal=1.0, model=sample_trait_model)
with pytest.raises(
ValueError, match="num_causal must be an integer not less than 1"
):
tstrait.sim_trait(
ts=sample_ts, num_causal=num_causal, model=sample_trait_model
)


class TestOutputDim:
"""Check that the sim_trait function gives the correct output regardless of the
Expand Down
2 changes: 2 additions & 0 deletions tstrait/simulate_effect_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tstrait

from .base import _check_instance
from .base import _check_int


class _TraitSimulator:
Expand Down Expand Up @@ -120,6 +121,7 @@ def sim_trait(ts, num_causal, model, random_seed=None):
"""
ts = _check_instance(ts, "ts", tskit.TreeSequence)
model = _check_instance(model, "model", tstrait.TraitModel)
num_causal = _check_int(num_causal, "num_causal", minimum=1)
num_sites = ts.num_sites
if num_sites == 0:
raise ValueError("No mutation in the tree sequence input")
Expand Down

0 comments on commit 5371a97

Please sign in to comment.