diff --git a/tests/fixtures.py b/tests/fixtures.py index 9d2d75802..24354642c 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -182,6 +182,7 @@ def __init__( LOGGING_STEP_SIZES = [5.0] NUM_SAMPLES = [2] +SEIRHD_NPI_STATIC_PARAM_INTERV = [{torch.tensor(10.0): {"delta": torch.tensor(0.2)}}] NON_POS_INTS = [ 3.5, -3, diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 4b58a5710..91b62720f 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -30,6 +30,7 @@ NON_POS_INTS, NUM_SAMPLES, OPT_MODELS, + SEIRHD_NPI_STATIC_PARAM_INTERV, START_TIMES, check_result_sizes, check_states_match, @@ -731,3 +732,38 @@ def test_errors_for_bad_amrs( logging_step_size, num_samples, ) + + +@pytest.mark.parametrize("sample_method", [sample]) +@pytest.mark.parametrize("model_fixture", MODELS) +@pytest.mark.parametrize("end_time", END_TIMES) +@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +@pytest.mark.parametrize("num_samples", NUM_SAMPLES) +@pytest.mark.parametrize("start_time", START_TIMES) +@pytest.mark.parametrize("seirhd_npi_intervention", SEIRHD_NPI_STATIC_PARAM_INTERV) +def test_intervention_on_constant_param( + sample_method, + model_fixture, + end_time, + logging_step_size, + num_samples, + start_time, + seirhd_npi_intervention, +): + # Assert that sample returns expected result with intervention on constant parameter + if "SEIRHD_NPI" not in model_fixture.url: + pytest.skip("Only test 'SEIRHD_NPI' models with constant parameter delta") + else: + processed_result = sample_method( + model_fixture.url, + end_time, + logging_step_size, + num_samples, + start_time=start_time, + static_parameter_interventions=seirhd_npi_intervention, + )["data"] + assert isinstance(processed_result, pd.DataFrame) + assert processed_result.shape[0] == num_samples * len( + torch.arange(start_time, end_time + logging_step_size, logging_step_size) + ) + assert processed_result.shape[1] >= 2