Skip to content

Commit 6ecd5ff

Browse files
committed
Fix unit testing to work with msprime > 0.7.4.
Closes #518.
1 parent c19fd69 commit 6ecd5ff

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

stdpopsim/models.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,16 @@ def verify_demographic_events_equal(
127127
to the specified tolerances and raises a UnequalModelsError otherwise.
128128
"""
129129
# Get the low-level dictionary representations of the events.
130-
dicts1 = [event.get_ll_representation(num_populations) for event in events1]
131-
dicts2 = [event.get_ll_representation(num_populations) for event in events2]
130+
# XXX: Msprime introduced a breaking change between 0.7.4 and 1.0, removing
131+
# the num_populations parameter to get_ll_representation(). See #518.
132+
# When we depend on msprime 1.0, this should be changed to instead use
133+
# msprime.DemographicEvent.asdict().
134+
from inspect import signature
135+
ll_args = ()
136+
if len(signature(msprime.MassMigration.get_ll_representation).parameters) == 2:
137+
ll_args = (num_populations, )
138+
dicts1 = [event.get_ll_representation(*ll_args) for event in events1]
139+
dicts2 = [event.get_ll_representation(*ll_args) for event in events2]
132140
if len(dicts1) != len(dicts2):
133141
raise UnequalModelsError("Different numbers of demographic events")
134142
for d1, d2 in zip(dicts1, dicts2):

tests/test_models.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Tests for simulation model infrastructure.
33
"""
44
import unittest
5+
from unittest import mock
6+
import inspect
57
import itertools
68
import io
79
import sys
@@ -222,7 +224,7 @@ def test_different_types(self):
222224
msprime.PopulationParametersChange(time=1, initial_size=1),
223225
msprime.MigrationRateChange(time=1, rate=1),
224226
msprime.MassMigration(time=1, source=1),
225-
msprime.SimpleBottleneck(time=1)]
227+
msprime.SimpleBottleneck(time=1, population=0)]
226228
for a, b in itertools.combinations(events, 2):
227229
self.assertFalse(models.demographic_events_equal([a], [b], 1))
228230
self.assertFalse(models.demographic_events_equal([b], [a], 1))
@@ -322,6 +324,47 @@ def f(time=1, population=1, proportion=1):
322324
models.verify_demographic_events_equal([a], [b], 1)
323325

324326

327+
# Test compatibility shim for get_ll_representation change in msprime. See #518.
328+
# TODO: remove this when we depend on msprime 1.0.
329+
def wrap_get_ll(get_ll_func):
330+
num_params = len(inspect.signature(
331+
msprime.MassMigration.get_ll_representation).parameters)
332+
if num_params == 2:
333+
# Wrapper for msprime <= 0.7.4 to check for msprime >= 1.0 compat.
334+
def new_func(self):
335+
return get_ll_func(self, num_populations=1)
336+
elif num_params == 1:
337+
# Wrapper for msprime >= 1.0 to check for msprime <= 0.7.4 compat.
338+
def new_func(self, num_populations=1):
339+
return get_ll_func(self)
340+
else:
341+
raise RuntimeError(
342+
"Unexpected signature for msprime.MassMigration.get_ll_representation")
343+
return new_func
344+
345+
346+
@mock.patch(
347+
"msprime.PopulationParametersChange.get_ll_representation",
348+
wrap_get_ll(msprime.PopulationParametersChange.get_ll_representation))
349+
@mock.patch(
350+
"msprime.MigrationRateChange.get_ll_representation",
351+
wrap_get_ll(msprime.MigrationRateChange.get_ll_representation))
352+
@mock.patch(
353+
"msprime.MassMigration.get_ll_representation",
354+
wrap_get_ll(msprime.MassMigration.get_ll_representation))
355+
@mock.patch(
356+
"msprime.SimpleBottleneck.get_ll_representation",
357+
wrap_get_ll(msprime.SimpleBottleneck.get_ll_representation))
358+
@mock.patch(
359+
"msprime.InstantaneousBottleneck.get_ll_representation",
360+
wrap_get_ll(msprime.InstantaneousBottleneck.get_ll_representation))
361+
@mock.patch(
362+
"msprime.CensusEvent.get_ll_representation",
363+
wrap_get_ll(msprime.CensusEvent.get_ll_representation))
364+
class TestDemographicEventsEqualMsprimeGetLLCompat(TestDemographicEventsEqual):
365+
pass
366+
367+
325368
class TestRegisterQCModel(unittest.TestCase):
326369

327370
def make_model(self, name):

0 commit comments

Comments
 (0)