From 1680f7024442e8260e7c5449f5adb2ded147fb16 Mon Sep 17 00:00:00 2001 From: August <30163079+augeorge@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:02:05 -0700 Subject: [PATCH] added tests for optimize and calibration --- tests/test_interfaces.py | 57 +++++++++++----------------------------- 1 file changed, 16 insertions(+), 41 deletions(-) diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 8fbe2ea2..b29de8cf 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -87,43 +87,8 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size): "num_iterations": 2, } -RTOL = [1e-6, 1e-5, 1e-4] -ATOL = [1e-8, 1e-7, 1e-6] - -@pytest.mark.parametrize("sample_method", SAMPLE_METHODS) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("start_time", START_TIMES) -@pytest.mark.parametrize("end_time", END_TIMES) -@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) -@pytest.mark.parametrize("num_samples", NUM_SAMPLES) -def test_sample_no_interventions( - sample_method, model, start_time, end_time, logging_step_size, num_samples -): - model_url = model.url - - with pyro.poutine.seed(rng_seed=0): - result1 = sample_method( - model_url, end_time, logging_step_size, num_samples, start_time=start_time - )["unprocessed_result"] - with pyro.poutine.seed(rng_seed=0): - result2 = sample_method( - model_url, end_time, logging_step_size, num_samples, start_time=start_time - )["unprocessed_result"] - - result3 = sample_method( - model_url, end_time, logging_step_size, num_samples, start_time=start_time - )["unprocessed_result"] - - for result in [result1, result2, result3]: - assert isinstance(result, dict) - check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) - - check_states_match(result1, result2) - if model.has_distributional_parameters: - check_states_match_in_all_but_values(result1, result3) - - if sample_method.__name__ == "dummy_ensemble_sample": - assert "total_state" in result1.keys() +RTOL = [1e-6, 1e-4] +ATOL = [1e-8, 1e-6] @pytest.mark.parametrize("sample_method", SAMPLE_METHODS) @@ -134,7 +99,7 @@ def test_sample_no_interventions( @pytest.mark.parametrize("num_samples", NUM_SAMPLES) @pytest.mark.parametrize("rtol", RTOL) @pytest.mark.parametrize("atol", ATOL) -def test_sample_with_tolerance( +def test_sample_no_interventions( sample_method, model, start_time, end_time, logging_step_size, num_samples, rtol, atol ): model_url = model.url @@ -404,8 +369,10 @@ def test_calibrate_no_kwargs( @pytest.mark.parametrize("start_time", START_TIMES) @pytest.mark.parametrize("end_time", END_TIMES) @pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +@pytest.mark.parametrize("rtol", RTOL) +@pytest.mark.parametrize("atol", ATOL) def test_calibrate_deterministic( - model_fixture, start_time, end_time, logging_step_size + model_fixture, start_time, end_time, logging_step_size, rtol, atol ): model_url = model_fixture.url ( @@ -421,6 +388,8 @@ def test_calibrate_deterministic( "data_mapping": model_fixture.data_mapping, "start_time": start_time, "deterministic_learnable_parameters": deterministic_learnable_parameters, + "rtol": rtol, + "atol": atol, **CALIBRATE_KWARGS, } @@ -440,7 +409,7 @@ def test_calibrate_deterministic( assert torch.allclose(param_value, param_sample_2[param_name]) result = sample( - *sample_args, **sample_kwargs, inferred_parameters=inferred_parameters + *sample_args, **sample_kwargs, inferred_parameters=inferred_parameters, rtol=rtol, atol=atol )["unprocessed_result"] check_result_sizes(result, start_time, end_time, logging_step_size, 1) @@ -603,7 +572,9 @@ def test_output_format( @pytest.mark.parametrize("start_time", START_TIMES) @pytest.mark.parametrize("end_time", END_TIMES) @pytest.mark.parametrize("num_samples", NUM_SAMPLES) -def test_optimize(model_fixture, start_time, end_time, num_samples): +@pytest.mark.parametrize("rtol", RTOL) +@pytest.mark.parametrize("atol", ATOL) +def test_optimize(model_fixture, start_time, end_time, num_samples, rtol, atol): logging_step_size = 1.0 model_url = model_fixture.url @@ -627,6 +598,8 @@ def __call__(self, x): "maxiter": 1, "maxfeval": 2, "progress_hook": progress_hook, + "rtol": rtol, + "atol": atol } bounds_interventions = optimize_kwargs["bounds_interventions"] opt_result = optimize( @@ -665,6 +638,8 @@ def __call__(self, x): static_parameter_interventions=opt_intervention, solver_method=optimize_kwargs["solver_method"], solver_options=optimize_kwargs["solver_options"], + rtol=rtol, + atol=atol )["unprocessed_result"] intervened_result_subset = {