From f5c5bec731441377b845a3b702f01927113bf462 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Thu, 8 Aug 2024 11:10:53 -0400 Subject: [PATCH] Allow for interventions on constant parameters (#597) * making failing test for intervention on constant param * linting * add deterministic parameters to the trace --------- Co-authored-by: sabinala --- pyciemss/compiled_dynamics.py | 6 ++++-- tests/fixtures.py | 1 + tests/test_interfaces.py | 36 +++++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/pyciemss/compiled_dynamics.py b/pyciemss/compiled_dynamics.py index 2988c5feb..85248440b 100644 --- a/pyciemss/compiled_dynamics.py +++ b/pyciemss/compiled_dynamics.py @@ -101,13 +101,15 @@ def observables(self, X: State[torch.Tensor]) -> State[torch.Tensor]: def instantiate_parameters(self): # Initialize random parameters once before simulating. # This is necessary because the parameters are PyroSample objects. - for k in _compile_param_values(self.src).keys(): + for k, param in _compile_param_values(self.src).items(): param_name = get_name(k) # Separating the persistent parameters from the non-persistent ones # is necessary because the persistent parameters are PyroSample objects representing the distribution, # and should not be modified during intervention. param_val = getattr(self, f"persistent_{param_name}") - self.register_buffer(get_name(k), param_val) + if isinstance(param, torch.Tensor): + pyro.deterministic(f"persistent_{param_name}", param_val) + self.register_buffer(param_name, param_val) def forward( self, diff --git a/tests/fixtures.py b/tests/fixtures.py index 198d161cb..027d19557 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -193,6 +193,7 @@ def __init__( LOGGING_STEP_SIZES = [5.0] NUM_SAMPLES = [2] +SEIRHD_NPI_STATIC_PARAM_INTERV = [{torch.tensor(10.0): {"delta": torch.tensor(0.2)}}] NON_POS_INTS = [ 3.5, -3, diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 7e647c5b3..fea96e2a4 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -30,6 +30,7 @@ NON_POS_INTS, NUM_SAMPLES, OPT_MODELS, + SEIRHD_NPI_STATIC_PARAM_INTERV, START_TIMES, check_result_sizes, check_states_match, @@ -737,3 +738,38 @@ def test_errors_for_bad_amrs( logging_step_size, num_samples, ) + + +@pytest.mark.parametrize("sample_method", [sample]) +@pytest.mark.parametrize("model_fixture", MODELS) +@pytest.mark.parametrize("end_time", END_TIMES) +@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +@pytest.mark.parametrize("num_samples", NUM_SAMPLES) +@pytest.mark.parametrize("start_time", START_TIMES) +@pytest.mark.parametrize("seirhd_npi_intervention", SEIRHD_NPI_STATIC_PARAM_INTERV) +def test_intervention_on_constant_param( + sample_method, + model_fixture, + end_time, + logging_step_size, + num_samples, + start_time, + seirhd_npi_intervention, +): + # Assert that sample returns expected result with intervention on constant parameter + if "SEIRHD_NPI" not in model_fixture.url: + pytest.skip("Only test 'SEIRHD_NPI' models with constant parameter delta") + else: + processed_result = sample_method( + model_fixture.url, + end_time, + logging_step_size, + num_samples, + start_time=start_time, + static_parameter_interventions=seirhd_npi_intervention, + )["data"] + assert isinstance(processed_result, pd.DataFrame) + assert processed_result.shape[0] == num_samples * len( + torch.arange(start_time, end_time + logging_step_size, logging_step_size) + ) + assert processed_result.shape[1] >= 2