simulate TA4 interface (#399)
* progress towards TA4 interfaces

* minor tidying, still incomplete

* more tidying

* remove attempt at pickling, and added intervention test

* dynamic intervention tests working

* got predictive working

* added additional test

* update types

* update types

* fix upstream time collisions in ChiRho (#400)

* rename simulate to sample

* add pyciemss_logging_wrapper and documentation

* revise tests
SamWitty authored Oct 30, 2023
1 parent a7e5416 commit 99ab914
Showing 5 changed files with 365 additions and 154 deletions.
283 changes: 139 additions & 144 deletions pyciemss/interfaces.py
@@ -1,144 +1,139 @@
import functools
from typing import Generic, Optional, TypeVar

# By convention we use "T" to denote the type of the dynamical system, e.g. `ODE`, `PDE`, or `SDE`.
T = TypeVar("T")


class DynamicalSystem(Generic[T]):
"""
A dynamical system is a model of a system that evolves over time.
"""

pass


class Intervention(Generic[T]):
"""
An intervention is a change to a dynamical system that is not a change to the parameters.
"""

pass


class Data(Generic[T]):
"""
Data is a collection of observations of the dynamical system.
"""

pass


class InferredParameters(Generic[T]):
"""
InferredParameters are the parameters of the dynamical system that are inferred from data.
"""

pass


class Simulation(Generic[T]):
"""
A simulation is a collection of trajectories of a dynamical system.
"""

pass


class ObjectiveFunction(Generic[T]):
"""
An objective function is a function that is optimized to infer parameters from data.
"""

pass


class Constraints(Generic[T]):
"""
Constraints are constraints on the parameters of the dynamical system for optimization.
"""

pass


class OptimizationAlgorithm(Generic[T]):
"""
An optimization algorithm is an algorithm that is used to optimize the objective function subject to constraints.
"""

pass


class OptimizationResult(Generic[T]):
"""
An optimization result is the result of optimizing the objective function subject to constraints.
"""

pass


@functools.singledispatch
def setup_model(model: DynamicalSystem[T], *args, **kwargs) -> DynamicalSystem[T]:
"""
Instantiate a model for a particular configuration of initial conditions, boundary conditions, logging events, etc.
"""
raise NotImplementedError


@functools.singledispatch
def reset_model(model: DynamicalSystem[T], *args, **kwargs) -> DynamicalSystem[T]:
"""
Reset a model to its initial state.
reset_model * setup_model = id
"""
raise NotImplementedError


@functools.singledispatch
def intervene(
model: DynamicalSystem[T], intervention: Intervention[T], *args, **kwargs
) -> DynamicalSystem[T]:
"""
`intervene(model, intervention)` returns a new model where the intervention has been applied.
"""
raise NotImplementedError


@functools.singledispatch
def calibrate(
model: DynamicalSystem[T], data: Data[T], *args, **kwargs
) -> InferredParameters[T]:
"""
Infer parameters for a DynamicalSystem model conditional on data.
This is typically done using a variational approximation.
"""
raise NotImplementedError


@functools.singledispatch
def simulate(
model: DynamicalSystem[T],
inferred_parameters: Optional[InferredParameters[T]] = None,
*args,
**kwargs
) -> Simulation[T]:
"""
Simulate trajectories from a given `model`, conditional on specified `inferred_parameters` distribution.
If `inferred_parameters` is not given, this will sample from the prior distribution.
"""
raise NotImplementedError


@functools.singledispatch
def optimize(
model: DynamicalSystem[T],
objective_function: ObjectiveFunction[T],
constraints: Constraints[T],
optimization_algorithm: OptimizationAlgorithm[T],
*args,
**kwargs
) -> OptimizationResult[T]:
"""
Optimize the objective function subject to the constraints.
"""
raise NotImplementedError
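

# A minimal sketch (illustrative, not part of this commit) of how a backend was
# meant to specialize these singledispatch stubs: define a concrete
# DynamicalSystem subclass and register an implementation for it. The names
# `ODE`, `MyODESystem`, and `MyODESimulation` are hypothetical.


class ODE:
    """Tag type for the dynamical-system family, following the "T" convention above."""


class MyODESystem(DynamicalSystem[ODE]):
    pass


class MyODESimulation(Simulation[ODE]):
    pass


@simulate.register(MyODESystem)
def _simulate_my_ode_system(
    model: MyODESystem,
    inferred_parameters: Optional[InferredParameters[ODE]] = None,
    *args,
    **kwargs
) -> MyODESimulation:
    # A real backend would integrate the ODE here and wrap the trajectories.
    return MyODESimulation()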
import contextlib
from typing import Any, Callable, Dict, Optional, Union

import pyro
import torch
from chirho.dynamical.handlers import (
DynamicIntervention,
InterruptionEventLoop,
LogTrajectory,
StaticIntervention,
)
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import State

from pyciemss.compiled_dynamics import CompiledDynamics
from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper


@pyciemss_logging_wrapper
def sample(
model_path_or_json: Union[str, Dict],
end_time: float,
logging_step_size: float,
num_samples: int,
*,
solver_method="dopri5",
solver_options: Dict[str, Any] = {},
start_time: float = 0.0,
inferred_parameters: Optional[pyro.nn.PyroModule] = None,
static_interventions: Dict[float, Dict[str, torch.Tensor]] = {},
dynamic_interventions: Dict[
Callable[[Dict[str, torch.Tensor]], torch.Tensor], Dict[str, torch.Tensor]
] = {},
) -> Dict[str, torch.Tensor]:
"""
Load a model from a file, compile it into a probabilistic program, and sample from it.
Args:
model_path_or_json: Union[str, Dict]
- A path to an AMR model file or a JSON dictionary containing a model in AMR form.
end_time: float
- The end time of the sampled simulation.
logging_step_size: float
- The step size to use for logging the trajectory.
num_samples: int
- The number of samples to draw from the model.
solver_method: str
- The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq's `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
- By default we set the `start_time` to be 0.
inferred_parameters:
- A Pyro module that contains the inferred parameters of the model.
This is typically the result of `calibrate`.
- If not provided, we will use the default values from the AMR model.
static_interventions: Dict[float, Dict[str, torch.Tensor]]
- A dictionary of static interventions to apply to the model.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {state_variable_name: value}.
dynamic_interventions: Dict[Callable[[Dict[str, torch.Tensor]], torch.Tensor], Dict[str, torch.Tensor]]
- A dictionary of dynamic interventions to apply to the model.
- Each key is a function that takes in the current state of the model and returns a tensor.
When this function crosses 0, the dynamic intervention is applied.
- Each value is a dictionary of the form {state_variable_name: value}.
Returns:
result: Dict[str, torch.Tensor]
- Dictionary of outputs from the model.
- Each key is the name of a parameter or state variable in the model.
- Each value is a tensor of shape (num_samples, num_timepoints) for state variables
and (num_samples,) for parameters.
"""

model = CompiledDynamics.load(model_path_or_json)

timespan = torch.arange(start_time, end_time, logging_step_size)

static_intervention_handlers = [
StaticIntervention(time, State(**static_intervention_assignment))
for time, static_intervention_assignment in static_interventions.items()
]
dynamic_intervention_handlers = [
DynamicIntervention(event_fn, State(**dynamic_intervention_assignment))
for event_fn, dynamic_intervention_assignment in dynamic_interventions.items()
]

def wrapped_model():
with LogTrajectory(timespan) as lt:
with InterruptionEventLoop():
with contextlib.ExitStack() as stack:
for handler in (
static_intervention_handlers + dynamic_intervention_handlers
):
stack.enter_context(handler)
model(
torch.as_tensor(start_time),
torch.tensor(end_time),
TorchDiffEq(method=solver_method, options=solver_options),
)
# Adding deterministic nodes to the model so that we can access the trajectory in the Predictive object.
[pyro.deterministic(f"state_{k}", v) for k, v in lt.trajectory.items()]

return pyro.infer.Predictive(
wrapped_model, guide=inferred_parameters, num_samples=num_samples
)()
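

# A minimal usage sketch (illustrative, not part of this commit). The model
# path, the state variable name "I", and all numeric values are hypothetical
# placeholders.
if __name__ == "__main__":
    result = sample(
        "path/to/amr_model.json",  # hypothetical AMR model file
        end_time=10.0,
        logging_step_size=0.5,
        num_samples=3,
        # At t = 2.0, set the (hypothetical) state variable "I" to 10.0.
        static_interventions={2.0: {"I": torch.tensor(10.0)}},
    )
    # State variables come back as (num_samples, num_timepoints) tensors,
    # parameters as (num_samples,) tensors.
    print({k: tuple(v.shape) for k, v in result.items()})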


# # TODO
# def calibrate(
# model: CompiledDynamics, data: Data, *args, **kwargs
# ) -> pyro.nn.PyroModule:
# """
# Infer parameters for a DynamicalSystem model conditional on data.
# This is typically done using a variational approximation.
# """
# raise NotImplementedError


# # TODO
# def optimize(
# model: CompiledDynamics,
# objective_function: ObjectiveFunction,
# constraints: Constraints,
# optimization_algorithm: OptimizationAlgorithm,
# *args,
# **kwargs
# ) -> OptimizationResult:
# """
# Optimize the objective function subject to the constraints.
# """
# raise NotImplementedError
2 changes: 1 addition & 1 deletion setup.cfg
@@ -10,7 +10,7 @@ install_requires =
jupyter
torch >= 1.8.0
mira @ git+https://github.com/indralab/[email protected]
chirho @ git+https://github.com/BasisResearch/chirho@f44731d416c20cbf1615147f5e0ba6ef3afed78d
chirho @ git+https://github.com/BasisResearch/chirho@f3019d4b22f4e49261efbf8da90a30095af2afbc
sympytorch
torchdiffeq

54 changes: 52 additions & 2 deletions tests/model_fixtures.py → tests/fixtures.py
@@ -1,5 +1,9 @@
from typing import Dict, TypeVar

import torch

T = TypeVar("T")

# SEE https://github.com/DARPA-ASKEM/Model-Representations/issues/62 for discussion of valid models.

PETRI_URLS = [
@@ -23,5 +27,51 @@

MODEL_URLS = PETRI_URLS + REGNET_URLS + STOCKFLOW_URLS

START_TIMES = [torch.tensor(0.0), torch.tensor(1.0), torch.tensor(2.0)]
END_TIMES = [torch.tensor(3.0), torch.tensor(4.0), torch.tensor(5.0)]
START_TIMES = [0.0, 1.0, 2.0]
END_TIMES = [3.0, 4.0, 5.0]

LOGGING_STEP_SIZES = [0.1]

NUM_SAMPLES = [2]


def check_keys_match(obj1: Dict[str, T], obj2: Dict[str, T]):
assert set(obj1.keys()) == set(obj2.keys()), "Objects have different variables."
return True


def check_states_match_in_all_but_values(
traj1: Dict[str, torch.Tensor], traj2: Dict[str, torch.Tensor]
):
assert check_keys_match(traj1, traj2)

for k in traj1.keys():
if k[:5] == "state":
assert not torch.allclose(
traj2[k], traj1[k]
), f"Trajectories are identical in state trajectory of variable {k}, but should differ."

return True


def check_result_sizes(
traj: Dict[str, torch.Tensor],
start_time: float,
end_time: float,
logging_step_size: float,
num_samples: int,
):
for k, v in traj.items():
assert isinstance(k, str)
assert isinstance(v, torch.Tensor)

if k[:5] == "state":
assert v.shape == (
num_samples,
len(torch.arange(start_time, end_time, logging_step_size))
- 1, # Does not include start_time
)
else:
assert v.shape == (num_samples,)

return True
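

# A minimal sketch (illustrative, not part of this commit) of how these fixtures
# and check helpers might be combined in a parametrized test; the test name and
# the use of model URLs as inputs to `sample` are assumptions.
import pytest

from pyciemss.interfaces import sample


@pytest.mark.parametrize("url", MODEL_URLS)
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
def test_sample_result_sizes(url, start_time, end_time, logging_step_size, num_samples):
    # Sample from the prior of each fixture model and check the output shapes.
    result = sample(url, end_time, logging_step_size, num_samples, start_time=start_time)
    assert check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)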