Skip to content

Commit

Permalink
more progress
Browse files Browse the repository at this point in the history
  • Loading branch information
SamWitty committed Nov 1, 2023
1 parent 4f0733b commit c4bef92
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
26 changes: 15 additions & 11 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
InterruptionEventLoop,
LogTrajectory,
StaticIntervention,
StaticBatchObservation,
)
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import State
Expand Down Expand Up @@ -84,7 +85,7 @@ def sample(

model = CompiledDynamics.load(model_path_or_json)

timespan = torch.arange(start_time, end_time, logging_step_size)
timespan = torch.arange(start_time+logging_step_size, end_time, logging_step_size)

static_intervention_handlers = [
StaticIntervention(time, State(**static_intervention_assignment))
Expand Down Expand Up @@ -124,6 +125,7 @@ def wrapped_model():
def calibrate(
model_path_or_json: Union[str, Dict],
data: dict[str, torch.Tensor],
data_timepoints: torch.Tensor,
start_time: float,
*,
noise_model: str = "normal",
Expand Down Expand Up @@ -160,9 +162,6 @@ def autoguide(model):
)
return guide

# TODO
end_time = ...

static_intervention_handlers = [
StaticIntervention(time, State(**static_intervention_assignment))
for time, static_intervention_assignment in static_interventions.items()
Expand All @@ -173,13 +172,18 @@ def autoguide(model):
]

def wrapped_model():
with InterruptionEventLoop():
with contextlib.ExitStack() as stack:
for handler in (
static_intervention_handlers + dynamic_intervention_handlers
):
stack.enter_context(handler)
model(torch.as_tensor(start_time), torch.as_tensor(end_time))

# TODO: pick up here.
obs = chirho.condition()

with StaticBatchObservation():
with InterruptionEventLoop():
with contextlib.ExitStack() as stack:
for handler in (
static_intervention_handlers + dynamic_intervention_handlers
):
stack.enter_context(handler)
model(torch.as_tensor(start_time), torch.as_tensor(data_timepoints[-1]))

guide = autoguide(wrapped_model)
optim = pyro.optim.Adam({"lr": lr})
Expand Down
3 changes: 3 additions & 0 deletions pyciemss/mira_integration/compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def _compile_param_values_mira(
for param_info in src.parameters.values():
param_name = get_name(param_info)

if param_info.placeholder:
continue

param_dist = getattr(param_info, "distribution", None)
if param_dist is None:
param_value = param_info.value
Expand Down
6 changes: 3 additions & 3 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/ont_pop_vax.json",
# "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir.json",
# "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_flux_span.json",
"https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed.json",
"https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed_aug.json",
# "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed.json",
# "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed_aug.json",
]

REGNET_URLS = [
Expand All @@ -22,7 +22,7 @@
]

STOCKFLOW_URLS = [
# "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/stockflow/examples/sir.json"
"https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/stockflow/examples/sir.json"
]

MODEL_URLS = PETRI_URLS + REGNET_URLS + STOCKFLOW_URLS
Expand Down
2 changes: 2 additions & 0 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,4 +229,6 @@ def test_calibrate(model_url, start_time, end_time, logging_step_size):
noise_model_kwargs={"scale": 0.1},
)

data_timespan = torch.arange(start_time+logging_step_size, end_time, logging_step_size)

assert isinstance(data, dict)

0 comments on commit c4bef92

Please sign in to comment.