Skip to content

Commit

Permalink
added tests for optimize and calibration
Browse files Browse the repository at this point in the history
  • Loading branch information
augeorge committed Sep 17, 2024
1 parent 4af3fab commit 1680f70
Showing 1 changed file with 16 additions and 41 deletions.
57 changes: 16 additions & 41 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,43 +87,8 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size):
"num_iterations": 2,
}

RTOL = [1e-6, 1e-5, 1e-4]
ATOL = [1e-8, 1e-7, 1e-6]

@pytest.mark.parametrize("sample_method", SAMPLE_METHODS)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
def test_sample_no_interventions(
sample_method, model, start_time, end_time, logging_step_size, num_samples
):
model_url = model.url

with pyro.poutine.seed(rng_seed=0):
result1 = sample_method(
model_url, end_time, logging_step_size, num_samples, start_time=start_time
)["unprocessed_result"]
with pyro.poutine.seed(rng_seed=0):
result2 = sample_method(
model_url, end_time, logging_step_size, num_samples, start_time=start_time
)["unprocessed_result"]

result3 = sample_method(
model_url, end_time, logging_step_size, num_samples, start_time=start_time
)["unprocessed_result"]

for result in [result1, result2, result3]:
assert isinstance(result, dict)
check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)

check_states_match(result1, result2)
if model.has_distributional_parameters:
check_states_match_in_all_but_values(result1, result3)

if sample_method.__name__ == "dummy_ensemble_sample":
assert "total_state" in result1.keys()
RTOL = [1e-6, 1e-4]
ATOL = [1e-8, 1e-6]


@pytest.mark.parametrize("sample_method", SAMPLE_METHODS)
Expand All @@ -134,7 +99,7 @@ def test_sample_no_interventions(
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
@pytest.mark.parametrize("rtol", RTOL)
@pytest.mark.parametrize("atol", ATOL)
def test_sample_with_tolerance(
def test_sample_no_interventions(
sample_method, model, start_time, end_time, logging_step_size, num_samples, rtol, atol
):
model_url = model.url
Expand Down Expand Up @@ -404,8 +369,10 @@ def test_calibrate_no_kwargs(
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
@pytest.mark.parametrize("rtol", RTOL)
@pytest.mark.parametrize("atol", ATOL)
def test_calibrate_deterministic(
model_fixture, start_time, end_time, logging_step_size
model_fixture, start_time, end_time, logging_step_size, rtol, atol
):
model_url = model_fixture.url
(
Expand All @@ -421,6 +388,8 @@ def test_calibrate_deterministic(
"data_mapping": model_fixture.data_mapping,
"start_time": start_time,
"deterministic_learnable_parameters": deterministic_learnable_parameters,
"rtol": rtol,
"atol": atol,
**CALIBRATE_KWARGS,
}

Expand All @@ -440,7 +409,7 @@ def test_calibrate_deterministic(
assert torch.allclose(param_value, param_sample_2[param_name])

result = sample(
*sample_args, **sample_kwargs, inferred_parameters=inferred_parameters
*sample_args, **sample_kwargs, inferred_parameters=inferred_parameters, rtol=rtol, atol=atol
)["unprocessed_result"]

check_result_sizes(result, start_time, end_time, logging_step_size, 1)
Expand Down Expand Up @@ -603,7 +572,9 @@ def test_output_format(
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
def test_optimize(model_fixture, start_time, end_time, num_samples):
@pytest.mark.parametrize("rtol", RTOL)
@pytest.mark.parametrize("atol", ATOL)
def test_optimize(model_fixture, start_time, end_time, num_samples, rtol, atol):
logging_step_size = 1.0
model_url = model_fixture.url

Expand All @@ -627,6 +598,8 @@ def __call__(self, x):
"maxiter": 1,
"maxfeval": 2,
"progress_hook": progress_hook,
"rtol": rtol,
"atol": atol
}
bounds_interventions = optimize_kwargs["bounds_interventions"]
opt_result = optimize(
Expand Down Expand Up @@ -665,6 +638,8 @@ def __call__(self, x):
static_parameter_interventions=opt_intervention,
solver_method=optimize_kwargs["solver_method"],
solver_options=optimize_kwargs["solver_options"],
rtol=rtol,
atol=atol
)["unprocessed_result"]

intervened_result_subset = {
Expand Down

0 comments on commit 1680f70

Please sign in to comment.