diff --git a/pyciemss/mira_integration/compiled_dynamics.py b/pyciemss/mira_integration/compiled_dynamics.py
index 27b24b1a..1584ad93 100644
--- a/pyciemss/mira_integration/compiled_dynamics.py
+++ b/pyciemss/mira_integration/compiled_dynamics.py
@@ -102,7 +102,7 @@ def _compile_param_values_mira(
         elif isinstance(param_value, pyro.distributions.Distribution):
             values[param_name] = pyro.nn.PyroSample(param_value)
         elif isinstance(param_value, (numbers.Number, numpy.ndarray, torch.Tensor)):
-            values[param_name] = torch.as_tensor(param_value)
+            values[param_name] = torch.as_tensor(param_value, dtype=torch.float32)
         else:
             raise TypeError(f"Unknown parameter type: {type(param_value)}")
 
@@ -126,7 +126,7 @@ def _eval_deriv_mira(
     dX: State[torch.Tensor] = dict()
     for i, var in enumerate(src.variables.values()):
         k = get_name(var)
-        dX[k] = numeric_deriv[..., i]
+        dX[k] = numeric_deriv[..., i].float()
 
     return dX
 
@@ -146,7 +146,7 @@ def _eval_initial_state_mira(
     X: State[torch.Tensor] = dict()
     for i, var in enumerate(src.variables.values()):
         k = get_name(var)
-        X[k] = numeric_initial_state[..., i]
+        X[k] = numeric_initial_state[..., i].float()
 
     return X
 
@@ -179,7 +179,7 @@ def _eval_observables_mira(
     observables: State[torch.Tensor] = dict()
     for i, obs in enumerate(src.observables.values()):
         k = get_name(obs)
-        observables[k] = numeric_observables[..., i]
+        observables[k] = numeric_observables[..., i].float()
 
     return observables
 
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 78dd5c46..afccb115 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -30,6 +30,7 @@ def __init__(
         data_mapping: Dict[str, str] = {},
         data_mapped_to_observable: bool = False,
         optimize_kwargs: Dict[str, Any] = None,
+        has_distributional_parameters: bool = True,
     ):
         self.url = url
         self.important_parameter = important_parameter
@@ -37,6 +38,7 @@ def __init__(
         self.data_mapping = data_mapping
         self.data_mapped_to_observable = data_mapped_to_observable
         self.optimize_kwargs = optimize_kwargs
+        self.has_distributional_parameters = has_distributional_parameters
 
 
 # See https://github.com/DARPA-ASKEM/Model-Representations/issues/62 for discussion of valid models.
@@ -85,7 +87,11 @@ def __init__(
     ModelFixture(
         os.path.join(MODELS_PATH, "LV_rabbits_wolves_model03_regnet.json"), "beta"
     ),
-    # ModelFixture(os.path.join(MODELS_PATH, "LV_goat_chupacabra_regnet.json"), "beta"),
+    ModelFixture(
+        os.path.join(MODELS_PATH, "LacOperon.json"),
+        "k_1",
+        has_distributional_parameters=False,
+    ),
 ]
 
 STOCKFLOW_MODELS = [
diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py
index 113c42c0..1f653b06 100644
--- a/tests/test_interfaces.py
+++ b/tests/test_interfaces.py
@@ -89,14 +89,16 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size):
 
 
 @pytest.mark.parametrize("sample_method", SAMPLE_METHODS)
-@pytest.mark.parametrize("model_url", MODEL_URLS)
+@pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("start_time", START_TIMES)
 @pytest.mark.parametrize("end_time", END_TIMES)
 @pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
 @pytest.mark.parametrize("num_samples", NUM_SAMPLES)
 def test_sample_no_interventions(
-    sample_method, model_url, start_time, end_time, logging_step_size, num_samples
+    sample_method, model, start_time, end_time, logging_step_size, num_samples
 ):
+    model_url = model.url
+
     with pyro.poutine.seed(rng_seed=0):
         result1 = sample_method(
             model_url, end_time, logging_step_size, num_samples, start_time=start_time
@@ -115,7 +117,8 @@ def test_sample_no_interventions(
         check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)
 
     check_states_match(result1, result2)
-    check_states_match_in_all_but_values(result1, result3)
+    if model.has_distributional_parameters:
+        check_states_match_in_all_but_values(result1, result3)
 
     if sample_method.__name__ == "dummy_ensemble_sample":
         assert "total_state" in result1.keys()