Coerce AMR inputs to float (#604)
* add failing test case

* lint

* coerce datatypes

* move LacOperon model to main

* remove important parameter for LacOperon

* make distributional parameters optional

* lint
SamWitty authored Sep 3, 2024
1 parent e224fd5 commit 221c424
Showing 3 changed files with 17 additions and 8 deletions.
pyciemss/mira_integration/compiled_dynamics.py (4 additions, 4 deletions)
@@ -102,7 +102,7 @@ def _compile_param_values_mira(
         elif isinstance(param_value, pyro.distributions.Distribution):
             values[param_name] = pyro.nn.PyroSample(param_value)
         elif isinstance(param_value, (numbers.Number, numpy.ndarray, torch.Tensor)):
-            values[param_name] = torch.as_tensor(param_value)
+            values[param_name] = torch.as_tensor(param_value, dtype=torch.float32)
         else:
             raise TypeError(f"Unknown parameter type: {type(param_value)}")

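For context on the hunk above (not part of the commit): torch.as_tensor infers its dtype from the input, so a NumPy double array comes through as float64 unless a dtype is forced. A minimal sketch of the mismatch the explicit dtype avoids:

import numpy
import torch

# NumPy floats default to 64-bit, and torch.as_tensor preserves that dtype ...
param = torch.as_tensor(numpy.array(0.5))
assert param.dtype == torch.float64

# ... while an explicit dtype pins parameters to float32, matching the rest
# of the compiled model.
coerced = torch.as_tensor(numpy.array(0.5), dtype=torch.float32)
assert coerced.dtype == torch.float32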
@@ -126,7 +126,7 @@ def _eval_deriv_mira(
     dX: State[torch.Tensor] = dict()
     for i, var in enumerate(src.variables.values()):
         k = get_name(var)
-        dX[k] = numeric_deriv[..., i]
+        dX[k] = numeric_deriv[..., i].float()
     return dX


@@ -146,7 +146,7 @@ def _eval_initial_state_mira(
     X: State[torch.Tensor] = dict()
     for i, var in enumerate(src.variables.values()):
         k = get_name(var)
-        X[k] = numeric_initial_state[..., i]
+        X[k] = numeric_initial_state[..., i].float()
     return X


@@ -179,7 +179,7 @@ def _eval_observables_mira(
     observables: State[torch.Tensor] = dict()
     for i, obs in enumerate(src.observables.values()):
         k = get_name(obs)
-        observables[k] = numeric_observables[..., i]
+        observables[k] = numeric_observables[..., i].float()

     return observables

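The three .float() calls in the hunks above apply the same coercion on the output side: Tensor.float() is shorthand for .to(torch.float32) and is a no-op when the tensor already has that dtype. A small illustration (not from the repository):

import torch

# Stand-in for float64 output from an ODE solver or observable evaluation.
numeric_deriv = torch.zeros(3, 2, dtype=torch.float64)
dX_col = numeric_deriv[..., 0].float()  # slice one state column, cast to float32
assert dX_col.dtype == torch.float32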
tests/fixtures.py (7 additions, 1 deletion)
@@ -30,13 +30,15 @@ def __init__(
         data_mapping: Dict[str, str] = {},
         data_mapped_to_observable: bool = False,
         optimize_kwargs: Dict[str, Any] = None,
+        has_distributional_parameters: bool = True,
     ):
         self.url = url
         self.important_parameter = important_parameter
         self.data_path = data_path
         self.data_mapping = data_mapping
         self.data_mapped_to_observable = data_mapped_to_observable
         self.optimize_kwargs = optimize_kwargs
+        self.has_distributional_parameters = has_distributional_parameters


 # See https://github.com/DARPA-ASKEM/Model-Representations/issues/62 for discussion of valid models.
@@ -85,7 +87,11 @@ def __init__(
     ModelFixture(
         os.path.join(MODELS_PATH, "LV_rabbits_wolves_model03_regnet.json"), "beta"
     ),
-    # ModelFixture(os.path.join(MODELS_PATH, "LV_goat_chupacabra_regnet.json"), "beta"),
+    ModelFixture(
+        os.path.join(MODELS_PATH, "LacOperon.json"),
+        "k_1",
+        has_distributional_parameters=False,
+    ),
 ]

 STOCKFLOW_MODELS = [
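For orientation, a hypothetical snippet (not in the repository; the import path is assumed) showing why the test file below parametrizes over whole fixtures rather than bare URLs: a ModelFixture object carries per-model metadata, such as the new has_distributional_parameters flag, alongside its URL:

import pytest

from fixtures import MODELS  # assumed to aggregate the fixture lists above

@pytest.mark.parametrize("model", MODELS)
def test_fixture_metadata_rides_along(model):
    # Hypothetical check: each fixture exposes its URL and its metadata.
    assert isinstance(model.url, str)
    assert isinstance(model.has_distributional_parameters, bool)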
tests/test_interfaces.py (6 additions, 3 deletions)
@@ -89,14 +89,16 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size):
 
 
 @pytest.mark.parametrize("sample_method", SAMPLE_METHODS)
-@pytest.mark.parametrize("model_url", MODEL_URLS)
+@pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("start_time", START_TIMES)
 @pytest.mark.parametrize("end_time", END_TIMES)
 @pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
 @pytest.mark.parametrize("num_samples", NUM_SAMPLES)
 def test_sample_no_interventions(
-    sample_method, model_url, start_time, end_time, logging_step_size, num_samples
+    sample_method, model, start_time, end_time, logging_step_size, num_samples
 ):
+    model_url = model.url
+
     with pyro.poutine.seed(rng_seed=0):
         result1 = sample_method(
             model_url, end_time, logging_step_size, num_samples, start_time=start_time
@@ -115,7 +117,8 @@
         check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)

     check_states_match(result1, result2)
-    check_states_match_in_all_but_values(result1, result3)
+    if model.has_distributional_parameters:
+        check_states_match_in_all_but_values(result1, result3)

     if sample_method.__name__ == "dummy_ensemble_sample":
         assert "total_state" in result1.keys()
