diff --git a/pyciemss/mira_integration/compiled_dynamics.py b/pyciemss/mira_integration/compiled_dynamics.py index 2cb668df6..27b24b1a3 100644 --- a/pyciemss/mira_integration/compiled_dynamics.py +++ b/pyciemss/mira_integration/compiled_dynamics.py @@ -159,7 +159,22 @@ def _eval_observables_mira( if len(src.observables) == 0: return dict() - numeric_observables = param_module.numeric_observables_func(**X) + parameters = { + get_name(param_info): getattr(param_module, get_name(param_info)) + for param_info in src.parameters.values() + if not param_info.placeholder + } + + # TODO: support event_dim > 0 upstream in ChiRho + # Default to time being the rightmost dimension + parameters_expanded = { + k: torch.unsqueeze(v, -1) if len(v.size()) > 0 else v + for k, v in parameters.items() + } + + numeric_observables = param_module.numeric_observables_func( + **X, **parameters_expanded + ) observables: State[torch.Tensor] = dict() for i, obs in enumerate(src.observables.values()): diff --git a/tests/fixtures.py b/tests/fixtures.py index 06a627228..9ea0cd875 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -67,6 +67,13 @@ def __init__( ), "u", ), + ModelFixture( + os.path.join(MODELS_PATH, "SIR_param_in_observables.json"), + "beta", + os.path.join(DATA_PATH, "SIR_data_case_hosp.csv"), + {"case": "incident_cases", "hosp": "I"}, + True, + ), ] REGNET_MODELS = [