From db587da509226900631f6d03d6026f7cc4df919d Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Wed, 5 Jun 2024 13:32:57 -0400 Subject: [PATCH 1/5] add parameters to observables --- pyciemss/mira_integration/compiled_dynamics.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pyciemss/mira_integration/compiled_dynamics.py b/pyciemss/mira_integration/compiled_dynamics.py index 2cb668df6..5f2dcc9eb 100644 --- a/pyciemss/mira_integration/compiled_dynamics.py +++ b/pyciemss/mira_integration/compiled_dynamics.py @@ -158,8 +158,14 @@ def _eval_observables_mira( ) -> State[torch.Tensor]: if len(src.observables) == 0: return dict() + + parameters = { + get_name(param_info): getattr(param_module, get_name(param_info)) + for param_info in src.parameters.values() + if not param_info.placeholder + } - numeric_observables = param_module.numeric_observables_func(**X) + numeric_observables = param_module.numeric_observables_func(**X, **parameters) observables: State[torch.Tensor] = dict() for i, obs in enumerate(src.observables.values()): From 8a41dfd99aac4b81d3f774d36bd68608628b3c47 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Wed, 5 Jun 2024 15:07:00 -0400 Subject: [PATCH 2/5] lint --- pyciemss/mira_integration/compiled_dynamics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyciemss/mira_integration/compiled_dynamics.py b/pyciemss/mira_integration/compiled_dynamics.py index 5f2dcc9eb..fc6f22500 100644 --- a/pyciemss/mira_integration/compiled_dynamics.py +++ b/pyciemss/mira_integration/compiled_dynamics.py @@ -158,7 +158,7 @@ def _eval_observables_mira( ) -> State[torch.Tensor]: if len(src.observables) == 0: return dict() - + parameters = { get_name(param_info): getattr(param_module, get_name(param_info)) for param_info in src.parameters.values() From 9aad7f0835fcb7afc336567da7c3b630995f06a6 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Wed, 5 Jun 2024 16:03:38 -0400 Subject: [PATCH 3/5] added test --- tests/fixtures.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/fixtures.py b/tests/fixtures.py index 06a627228..5f0d79bab 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -67,6 +67,13 @@ def __init__( ), "u", ), + ModelFixture( + os.path.join(MODELS_PATH, "SIR_param_in_observables.json"), + "beta", + os.path.join(DATA_PATH, "SIR_data_case_hosp.csv"), + {"case": "incident_cases", "hosp": "I"}, + True, + ), ] REGNET_MODELS = [ From 272540312488a244d40b48132c47669f417c1d7f Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Wed, 5 Jun 2024 16:04:10 -0400 Subject: [PATCH 4/5] progress towards broadcasting --- pyciemss/mira_integration/compiled_dynamics.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyciemss/mira_integration/compiled_dynamics.py b/pyciemss/mira_integration/compiled_dynamics.py index fc6f22500..59c50db92 100644 --- a/pyciemss/mira_integration/compiled_dynamics.py +++ b/pyciemss/mira_integration/compiled_dynamics.py @@ -165,7 +165,11 @@ def _eval_observables_mira( if not param_info.placeholder } - numeric_observables = param_module.numeric_observables_func(**X, **parameters) + # TODO: support event_dim > 0 upstream in ChiRho + # Default to time being the rightmost dimension + parameters_expanded = {k: torch.unsqueeze(v, -1) if len(v.size()) > 0 else v for k, v in parameters.items()} + + numeric_observables = param_module.numeric_observables_func(**X, **parameters_expanded) observables: State[torch.Tensor] = dict() for i, obs in enumerate(src.observables.values()): From 0264fa38f9b37acf958b63c77fea0e37950d5ec8 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Wed, 5 Jun 2024 16:35:43 -0400 Subject: [PATCH 5/5] lint --- pyciemss/mira_integration/compiled_dynamics.py | 9 +++++++-- tests/fixtures.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pyciemss/mira_integration/compiled_dynamics.py b/pyciemss/mira_integration/compiled_dynamics.py index 59c50db92..27b24b1a3 100644 --- a/pyciemss/mira_integration/compiled_dynamics.py +++ b/pyciemss/mira_integration/compiled_dynamics.py @@ -167,9 +167,14 @@ def _eval_observables_mira( # TODO: support event_dim > 0 upstream in ChiRho # Default to time being the rightmost dimension - parameters_expanded = {k: torch.unsqueeze(v, -1) if len(v.size()) > 0 else v for k, v in parameters.items()} + parameters_expanded = { + k: torch.unsqueeze(v, -1) if len(v.size()) > 0 else v + for k, v in parameters.items() + } - numeric_observables = param_module.numeric_observables_func(**X, **parameters_expanded) + numeric_observables = param_module.numeric_observables_func( + **X, **parameters_expanded + ) observables: State[torch.Tensor] = dict() for i, obs in enumerate(src.observables.values()): diff --git a/tests/fixtures.py b/tests/fixtures.py index 5f0d79bab..9ea0cd875 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -67,11 +67,11 @@ def __init__( ), "u", ), - ModelFixture( + ModelFixture( os.path.join(MODELS_PATH, "SIR_param_in_observables.json"), "beta", os.path.join(DATA_PATH, "SIR_data_case_hosp.csv"), - {"case": "incident_cases", "hosp": "I"}, + {"case": "incident_cases", "hosp": "I"}, True, ), ]