-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* progress towards TA4 interfaces * minor tidying, still incomplete * more tidying * remove attempt at pickling, and added intervention test * dynamic intervention tests working * got predictive working * added additional test * update types * update types * fix upstream time collisions in ChiRho (#400) * rename simulate to sample * add pyciemss_logging_wrapper and documentation * revise tests
- Loading branch information
Showing
5 changed files
with
365 additions
and
154 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,144 +1,139 @@ | ||
import functools | ||
from typing import Generic, Optional, TypeVar | ||
|
||
# By convention we use "T" to denote the type of the dynamical system, e.g. `ODE`, `PDE`, or `SDE`. | ||
T = TypeVar("T") | ||
|
||
|
||
class DynamicalSystem(Generic[T]): | ||
""" | ||
A dynamical system is a model of a system that evolves over time. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class Intervention(Generic[T]): | ||
""" | ||
An intervention is a change to a dynamical system that is not a change to the parameters. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class Data(Generic[T]): | ||
""" | ||
Data is a collection of observations of the dynamical system. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class InferredParameters(Generic[T]): | ||
""" | ||
InferredParameters are the parameters of the dynamical system that are inferred from data. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class Simulation(Generic[T]): | ||
""" | ||
A simulation is a collection of trajectories of a dynamical system. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class ObjectiveFunction(Generic[T]): | ||
""" | ||
An objective function is a function that is optimized to infer parameters from data. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class Constraints(Generic[T]): | ||
""" | ||
Constraints are constraints on the parameters of the dynamical system for optimization. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class OptimizationAlgorithm(Generic[T]): | ||
""" | ||
An optimization algorithm is an algorithm that is used to optimize the objective function subject to constraints. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class OptimizationResult(Generic[T]): | ||
""" | ||
An optimization result is the result of optimizing the objective function subject to constraints. | ||
""" | ||
|
||
pass | ||
|
||
|
||
@functools.singledispatch | ||
def setup_model(model: DynamicalSystem[T], *args, **kwargs) -> DynamicalSystem[T]: | ||
""" | ||
Instatiate a model for a particular configuration of initial conditions, boundary conditions, logging events, etc. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
@functools.singledispatch | ||
def reset_model(model: DynamicalSystem[T], *args, **kwargs) -> DynamicalSystem[T]: | ||
""" | ||
Reset a model to its initial state. | ||
reset_model * setup_model = id | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
@functools.singledispatch | ||
def intervene( | ||
model: DynamicalSystem[T], intervention: Intervention[T], *args, **kwargs | ||
) -> DynamicalSystem[T]: | ||
""" | ||
`intervene(model, intervention)` returns a new model where the intervention has been applied. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
@functools.singledispatch | ||
def calibrate( | ||
model: DynamicalSystem[T], data: Data[T], *args, **kwargs | ||
) -> InferredParameters[T]: | ||
""" | ||
Infer parameters for a DynamicalSystem model conditional on data. | ||
This is typically done using a variational approximation. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
@functools.singledispatch | ||
def simulate( | ||
model: DynamicalSystem[T], | ||
inferred_parameters: Optional[InferredParameters[T]] = None, | ||
*args, | ||
**kwargs | ||
) -> Simulation[T]: | ||
""" | ||
Simulate trajectories from a given `model`, conditional on specified `inferred_parameters` distribution. | ||
If `inferred_parameters` is not given, this will sample from the prior distribution. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
@functools.singledispatch | ||
def optimize( | ||
model: DynamicalSystem[T], | ||
objective_function: ObjectiveFunction[T], | ||
constraints: Constraints[T], | ||
optimization_algorithm: OptimizationAlgorithm[T], | ||
*args, | ||
**kwargs | ||
) -> OptimizationResult[T]: | ||
""" | ||
Optimize the objective function subject to the constraints. | ||
""" | ||
raise NotImplementedError | ||
import contextlib | ||
from typing import Any, Callable, Dict, Optional, Union | ||
|
||
import pyro | ||
import torch | ||
from chirho.dynamical.handlers import ( | ||
DynamicIntervention, | ||
InterruptionEventLoop, | ||
LogTrajectory, | ||
StaticIntervention, | ||
) | ||
from chirho.dynamical.handlers.solver import TorchDiffEq | ||
from chirho.dynamical.ops import State | ||
|
||
from pyciemss.compiled_dynamics import CompiledDynamics | ||
from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper | ||
|
||
|
||
@pyciemss_logging_wrapper | ||
def sample( | ||
model_path_or_json: Union[str, Dict], | ||
end_time: float, | ||
logging_step_size: float, | ||
num_samples: int, | ||
*, | ||
solver_method="dopri5", | ||
solver_options: Dict[str, Any] = {}, | ||
start_time: float = 0.0, | ||
inferred_parameters: Optional[pyro.nn.PyroModule] = None, | ||
static_interventions: Dict[float, Dict[str, torch.Tensor]] = {}, | ||
dynamic_interventions: Dict[ | ||
Callable[[Dict[str, torch.Tensor]], torch.Tensor], Dict[str, torch.Tensor] | ||
] = {}, | ||
) -> Dict[str, torch.Tensor]: | ||
""" | ||
Load a model from a file, compile it into a probabilistic program, and sample from it. | ||
Args: | ||
model_path_or_json: Union[str, Dict] | ||
- A path to a AMR model file or JSON containing a model in AMR form. | ||
end_time: float | ||
- The end time of the sampled simulation. | ||
logging_step_size: float | ||
- The step size to use for logging the trajectory. | ||
num_samples: int | ||
- The number of samples to draw from the model. | ||
interventions: Optional[Iterable[Tuple[float, str, float]]] | ||
- A list of interventions to apply to the model. | ||
Each intervention is a tuple of the form (time, parameter_name, value). | ||
solver_method: str | ||
- The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details. | ||
- If performance is incredibly slow, we suggest using `euler` to debug. | ||
If using `euler` results in faster simulation, the issue is likely that the model is stiff. | ||
solver_options: Dict[str, Any] | ||
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details. | ||
start_time: float | ||
- The start time of the model. This is used to align the `start_state` from the | ||
AMR model with the simulation timepoints. | ||
- By default we set the `start_time` to be 0. | ||
inferred_parameters: | ||
- A Pyro module that contains the inferred parameters of the model. | ||
This is typically the result of `calibrate`. | ||
- If not provided, we will use the default values from the AMR model. | ||
static_interventions: Dict[float, Dict[str, torch.Tensor]] | ||
- A dictionary of static interventions to apply to the model. | ||
- Each key is the time at which the intervention is applied. | ||
- Each value is a dictionary of the form {state_variable_name: value}. | ||
dynamic_interventions: Dict[Callable[[Dict[str, torch.Tensor]], torch.Tensor], Dict[str, torch.Tensor]] | ||
- A dictionary of dynamic interventions to apply to the model. | ||
- Each key is a function that takes in the current state of the model and returns a tensor. | ||
When this function crosses 0, the dynamic intervention is applied. | ||
- Each value is a dictionary of the form {state_variable_name: value}. | ||
Returns: | ||
result: Dict[str, torch.Tensor] | ||
- Dictionary of outputs from the model. | ||
- Each key is the name of a parameter or state variable in the model. | ||
- Each value is a tensor of shape (num_samples, num_timepoints) for state variables | ||
and (num_samples,) for parameters. | ||
""" | ||
|
||
model = CompiledDynamics.load(model_path_or_json) | ||
|
||
timespan = torch.arange(start_time, end_time, logging_step_size) | ||
|
||
static_intervention_handlers = [ | ||
StaticIntervention(time, State(**static_intervention_assignment)) | ||
for time, static_intervention_assignment in static_interventions.items() | ||
] | ||
dynamic_intervention_handlers = [ | ||
DynamicIntervention(event_fn, State(**dynamic_intervention_assignment)) | ||
for event_fn, dynamic_intervention_assignment in dynamic_interventions.items() | ||
] | ||
|
||
def wrapped_model(): | ||
with LogTrajectory(timespan) as lt: | ||
with InterruptionEventLoop(): | ||
with contextlib.ExitStack() as stack: | ||
for handler in ( | ||
static_intervention_handlers + dynamic_intervention_handlers | ||
): | ||
stack.enter_context(handler) | ||
model( | ||
torch.as_tensor(start_time), | ||
torch.tensor(end_time), | ||
TorchDiffEq(method=solver_method, options=solver_options), | ||
) | ||
# Adding deterministic nodes to the model so that we can access the trajectory in the Predictive object. | ||
[pyro.deterministic(f"state_{k}", v) for k, v in lt.trajectory.items()] | ||
|
||
return pyro.infer.Predictive( | ||
wrapped_model, guide=inferred_parameters, num_samples=num_samples | ||
)() | ||
|
||
|
||
# # TODO | ||
# def calibrate( | ||
# model: CompiledDynamics, data: Data, *args, **kwargs | ||
# ) -> pyro.nn.PyroModule: | ||
# """ | ||
# Infer parameters for a DynamicalSystem model conditional on data. | ||
# This is typically done using a variational approximation. | ||
# """ | ||
# raise NotImplementedError | ||
|
||
|
||
# # TODO | ||
# def optimize( | ||
# model: CompiledDynamics, | ||
# objective_function: ObjectiveFunction, | ||
# constraints: Constraints, | ||
# optimization_algorithm: OptimizationAlgorithm, | ||
# *args, | ||
# **kwargs | ||
# ) -> OptimizationResult: | ||
# """ | ||
# Optimize the objective function subject to the constraints. | ||
# """ | ||
# raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ install_requires = | |
jupyter | ||
torch >= 1.8.0 | ||
mira @ git+https://github.com/indralab/[email protected] | ||
chirho @ git+https://github.com/BasisResearch/chirho@f44731d416c20cbf1615147f5e0ba6ef3afed78d | ||
chirho @ git+https://github.com/BasisResearch/chirho@f3019d4b22f4e49261efbf8da90a30095af2afbc | ||
sympytorch | ||
torchdiffeq | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.