diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index 2df8313b3..ead4a628b 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -8,6 +8,7 @@ InterruptionEventLoop, LogTrajectory, StaticIntervention, + StaticBatchObservation, ) from chirho.dynamical.handlers.solver import TorchDiffEq from chirho.dynamical.ops import State @@ -84,7 +85,7 @@ def sample( model = CompiledDynamics.load(model_path_or_json) - timespan = torch.arange(start_time, end_time, logging_step_size) + timespan = torch.arange(start_time+logging_step_size, end_time, logging_step_size) static_intervention_handlers = [ StaticIntervention(time, State(**static_intervention_assignment)) @@ -124,6 +125,7 @@ def wrapped_model(): def calibrate( model_path_or_json: Union[str, Dict], data: dict[str, torch.Tensor], + data_timepoints: torch.Tensor, start_time: float, *, noise_model: str = "normal", @@ -160,9 +162,6 @@ def autoguide(model): ) return guide - # TODO - end_time = ... - static_intervention_handlers = [ StaticIntervention(time, State(**static_intervention_assignment)) for time, static_intervention_assignment in static_interventions.items() @@ -173,13 +172,18 @@ def autoguide(model): ] def wrapped_model(): - with InterruptionEventLoop(): - with contextlib.ExitStack() as stack: - for handler in ( - static_intervention_handlers + dynamic_intervention_handlers - ): - stack.enter_context(handler) - model(torch.as_tensor(start_time), torch.as_tensor(end_time)) + + # TODO: pick up here. + obs = chirho.condition() + + with StaticBatchObservation(): + with InterruptionEventLoop(): + with contextlib.ExitStack() as stack: + for handler in ( + static_intervention_handlers + dynamic_intervention_handlers + ): + stack.enter_context(handler) + model(torch.as_tensor(start_time), torch.as_tensor(data_timepoints[-1])) guide = autoguide(wrapped_model) optim = pyro.optim.Adam({"lr": lr}) diff --git a/pyciemss/mira_integration/compiled_dynamics.py b/pyciemss/mira_integration/compiled_dynamics.py index 9cd8748a4..806c75b1f 100644 --- a/pyciemss/mira_integration/compiled_dynamics.py +++ b/pyciemss/mira_integration/compiled_dynamics.py @@ -67,6 +67,9 @@ def _compile_param_values_mira( for param_info in src.parameters.values(): param_name = get_name(param_info) + if param_info.placeholder: + continue + param_dist = getattr(param_info, "distribution", None) if param_dist is None: param_value = param_info.value diff --git a/tests/fixtures.py b/tests/fixtures.py index ad77c11c0..ad4512f8b 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -12,8 +12,8 @@ # "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/ont_pop_vax.json", # "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir.json", # "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_flux_span.json", - "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed.json", - "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed_aug.json", + # "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed.json", + # "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed_aug.json", ] REGNET_URLS = [ @@ -22,7 +22,7 @@ ] STOCKFLOW_URLS = [ - # "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/stockflow/examples/sir.json" + "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/stockflow/examples/sir.json" ] MODEL_URLS = PETRI_URLS + REGNET_URLS + STOCKFLOW_URLS diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index c3236dd3e..0561be7b5 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -229,4 +229,6 @@ def test_calibrate(model_url, start_time, end_time, logging_step_size): noise_model_kwargs={"scale": 0.1}, ) + data_timespan = torch.arange(start_time+logging_step_size, end_time, logging_step_size) + assert isinstance(data, dict)