diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index 23756c8a3..4532bd918 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -1,144 +1,139 @@ -import functools -from typing import Generic, Optional, TypeVar - -# By convention we use "T" to denote the type of the dynamical system, e.g. `ODE`, `PDE`, or `SDE`. -T = TypeVar("T") - - -class DynamicalSystem(Generic[T]): - """ - A dynamical system is a model of a system that evolves over time. - """ - - pass - - -class Intervention(Generic[T]): - """ - An intervention is a change to a dynamical system that is not a change to the parameters. - """ - - pass - - -class Data(Generic[T]): - """ - Data is a collection of observations of the dynamical system. - """ - - pass - - -class InferredParameters(Generic[T]): - """ - InferredParameters are the parameters of the dynamical system that are inferred from data. - """ - - pass - - -class Simulation(Generic[T]): - """ - A simulation is a collection of trajectories of a dynamical system. - """ - - pass - - -class ObjectiveFunction(Generic[T]): - """ - An objective function is a function that is optimized to infer parameters from data. - """ - - pass - - -class Constraints(Generic[T]): - """ - Constraints are constraints on the parameters of the dynamical system for optimization. - """ - - pass - - -class OptimizationAlgorithm(Generic[T]): - """ - An optimization algorithm is an algorithm that is used to optimize the objective function subject to constraints. - """ - - pass - - -class OptimizationResult(Generic[T]): - """ - An optimization result is the result of optimizing the objective function subject to constraints. - """ - - pass - - -@functools.singledispatch -def setup_model(model: DynamicalSystem[T], *args, **kwargs) -> DynamicalSystem[T]: - """ - Instatiate a model for a particular configuration of initial conditions, boundary conditions, logging events, etc. - """ - raise NotImplementedError - - -@functools.singledispatch -def reset_model(model: DynamicalSystem[T], *args, **kwargs) -> DynamicalSystem[T]: - """ - Reset a model to its initial state. - reset_model * setup_model = id - """ - raise NotImplementedError - - -@functools.singledispatch -def intervene( - model: DynamicalSystem[T], intervention: Intervention[T], *args, **kwargs -) -> DynamicalSystem[T]: - """ - `intervene(model, intervention)` returns a new model where the intervention has been applied. - """ - raise NotImplementedError - - -@functools.singledispatch -def calibrate( - model: DynamicalSystem[T], data: Data[T], *args, **kwargs -) -> InferredParameters[T]: - """ - Infer parameters for a DynamicalSystem model conditional on data. - This is typically done using a variational approximation. - """ - raise NotImplementedError - - -@functools.singledispatch -def simulate( - model: DynamicalSystem[T], - inferred_parameters: Optional[InferredParameters[T]] = None, - *args, - **kwargs -) -> Simulation[T]: - """ - Simulate trajectories from a given `model`, conditional on specified `inferred_parameters` distribution. - If `inferred_parameters` is not given, this will sample from the prior distribution. - """ - raise NotImplementedError - - -@functools.singledispatch -def optimize( - model: DynamicalSystem[T], - objective_function: ObjectiveFunction[T], - constraints: Constraints[T], - optimization_algorithm: OptimizationAlgorithm[T], - *args, - **kwargs -) -> OptimizationResult[T]: - """ - Optimize the objective function subject to the constraints. - """ - raise NotImplementedError +import contextlib +from typing import Any, Callable, Dict, Optional, Union + +import pyro +import torch +from chirho.dynamical.handlers import ( + DynamicIntervention, + InterruptionEventLoop, + LogTrajectory, + StaticIntervention, +) +from chirho.dynamical.handlers.solver import TorchDiffEq +from chirho.dynamical.ops import State + +from pyciemss.compiled_dynamics import CompiledDynamics +from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper + + +@pyciemss_logging_wrapper +def sample( + model_path_or_json: Union[str, Dict], + end_time: float, + logging_step_size: float, + num_samples: int, + *, + solver_method="dopri5", + solver_options: Dict[str, Any] = {}, + start_time: float = 0.0, + inferred_parameters: Optional[pyro.nn.PyroModule] = None, + static_interventions: Dict[float, Dict[str, torch.Tensor]] = {}, + dynamic_interventions: Dict[ + Callable[[Dict[str, torch.Tensor]], torch.Tensor], Dict[str, torch.Tensor] + ] = {}, +) -> Dict[str, torch.Tensor]: + """ + Load a model from a file, compile it into a probabilistic program, and sample from it. + + Args: + model_path_or_json: Union[str, Dict] + - A path to a AMR model file or JSON containing a model in AMR form. + end_time: float + - The end time of the sampled simulation. + logging_step_size: float + - The step size to use for logging the trajectory. + num_samples: int + - The number of samples to draw from the model. + interventions: Optional[Iterable[Tuple[float, str, float]]] + - A list of interventions to apply to the model. + Each intervention is a tuple of the form (time, parameter_name, value). + solver_method: str + - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details. + - If performance is incredibly slow, we suggest using `euler` to debug. + If using `euler` results in faster simulation, the issue is likely that the model is stiff. + solver_options: Dict[str, Any] + - Options to pass to the solver. See torchdiffeq' `odeint` method for more details. + start_time: float + - The start time of the model. This is used to align the `start_state` from the + AMR model with the simulation timepoints. + - By default we set the `start_time` to be 0. + inferred_parameters: + - A Pyro module that contains the inferred parameters of the model. + This is typically the result of `calibrate`. + - If not provided, we will use the default values from the AMR model. + static_interventions: Dict[float, Dict[str, torch.Tensor]] + - A dictionary of static interventions to apply to the model. + - Each key is the time at which the intervention is applied. + - Each value is a dictionary of the form {state_variable_name: value}. + dynamic_interventions: Dict[Callable[[Dict[str, torch.Tensor]], torch.Tensor], Dict[str, torch.Tensor]] + - A dictionary of dynamic interventions to apply to the model. + - Each key is a function that takes in the current state of the model and returns a tensor. + When this function crosses 0, the dynamic intervention is applied. + - Each value is a dictionary of the form {state_variable_name: value}. + + Returns: + result: Dict[str, torch.Tensor] + - Dictionary of outputs from the model. + - Each key is the name of a parameter or state variable in the model. + - Each value is a tensor of shape (num_samples, num_timepoints) for state variables + and (num_samples,) for parameters. + """ + + model = CompiledDynamics.load(model_path_or_json) + + timespan = torch.arange(start_time, end_time, logging_step_size) + + static_intervention_handlers = [ + StaticIntervention(time, State(**static_intervention_assignment)) + for time, static_intervention_assignment in static_interventions.items() + ] + dynamic_intervention_handlers = [ + DynamicIntervention(event_fn, State(**dynamic_intervention_assignment)) + for event_fn, dynamic_intervention_assignment in dynamic_interventions.items() + ] + + def wrapped_model(): + with LogTrajectory(timespan) as lt: + with InterruptionEventLoop(): + with contextlib.ExitStack() as stack: + for handler in ( + static_intervention_handlers + dynamic_intervention_handlers + ): + stack.enter_context(handler) + model( + torch.as_tensor(start_time), + torch.tensor(end_time), + TorchDiffEq(method=solver_method, options=solver_options), + ) + # Adding deterministic nodes to the model so that we can access the trajectory in the Predictive object. + [pyro.deterministic(f"state_{k}", v) for k, v in lt.trajectory.items()] + + return pyro.infer.Predictive( + wrapped_model, guide=inferred_parameters, num_samples=num_samples + )() + + +# # TODO +# def calibrate( +# model: CompiledDynamics, data: Data, *args, **kwargs +# ) -> pyro.nn.PyroModule: +# """ +# Infer parameters for a DynamicalSystem model conditional on data. +# This is typically done using a variational approximation. +# """ +# raise NotImplementedError + + +# # TODO +# def optimize( +# model: CompiledDynamics, +# objective_function: ObjectiveFunction, +# constraints: Constraints, +# optimization_algorithm: OptimizationAlgorithm, +# *args, +# **kwargs +# ) -> OptimizationResult: +# """ +# Optimize the objective function subject to the constraints. +# """ +# raise NotImplementedError diff --git a/setup.cfg b/setup.cfg index 2f167065d..d8c222fe5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ install_requires = jupyter torch >= 1.8.0 mira @ git+https://github.com/indralab/mira.git@0.2.0 - chirho @ git+https://github.com/BasisResearch/chirho@f44731d416c20cbf1615147f5e0ba6ef3afed78d + chirho @ git+https://github.com/BasisResearch/chirho@f3019d4b22f4e49261efbf8da90a30095af2afbc sympytorch torchdiffeq diff --git a/tests/model_fixtures.py b/tests/fixtures.py similarity index 51% rename from tests/model_fixtures.py rename to tests/fixtures.py index 191357664..e934e56b7 100644 --- a/tests/model_fixtures.py +++ b/tests/fixtures.py @@ -1,5 +1,9 @@ +from typing import Dict, TypeVar + import torch +T = TypeVar("T") + # SEE https://github.com/DARPA-ASKEM/Model-Representations/issues/62 for discussion of valid models. PETRI_URLS = [ @@ -23,5 +27,51 @@ MODEL_URLS = PETRI_URLS + REGNET_URLS + STOCKFLOW_URLS -START_TIMES = [torch.tensor(0.0), torch.tensor(1.0), torch.tensor(2.0)] -END_TIMES = [torch.tensor(3.0), torch.tensor(4.0), torch.tensor(5.0)] +START_TIMES = [0.0, 1.0, 2.0] +END_TIMES = [3.0, 4.0, 5.0] + +LOGGING_STEP_SIZES = [0.1] + +NUM_SAMPLES = [2] + + +def check_keys_match(obj1: Dict[str, T], obj2: Dict[str, T]): + assert set(obj1.keys()) == set(obj2.keys()), "Objects have different variables." + return True + + +def check_states_match_in_all_but_values( + traj1: Dict[str, torch.Tensor], traj2: Dict[str, torch.Tensor] +): + assert check_keys_match(traj1, traj2) + + for k in traj1.keys(): + if k[:5] == "state": + assert not torch.allclose( + traj2[k], traj1[k] + ), f"Trajectories are identical in state trajectory of variable {k}, but should differ." + + return True + + +def check_result_sizes( + traj: Dict[str, torch.Tensor], + start_time: float, + end_time: float, + logging_step_size: float, + num_samples: int, +): + for k, v in traj.items(): + assert isinstance(k, str) + assert isinstance(v, torch.Tensor) + + if k[:5] == "state": + assert v.shape == ( + num_samples, + len(torch.arange(start_time, end_time, logging_step_size)) + - 1, # Does not include start_time + ) + else: + assert v.shape == (num_samples,) + + return True diff --git a/tests/test_compiled_dynamics.py b/tests/test_compiled_dynamics.py index f4d38ee0d..00d4f8093 100644 --- a/tests/test_compiled_dynamics.py +++ b/tests/test_compiled_dynamics.py @@ -3,10 +3,12 @@ import pytest import requests +import torch +from chirho.dynamical.ops import State from pyciemss.compiled_dynamics import CompiledDynamics -from .model_fixtures import END_TIMES, MODEL_URLS, START_TIMES +from .fixtures import END_TIMES, MODEL_URLS, START_TIMES @pytest.mark.parametrize("url", MODEL_URLS) @@ -16,8 +18,8 @@ def test_compiled_dynamics_load_url(url, start_time, end_time): model = CompiledDynamics.load(url) assert isinstance(model, CompiledDynamics) - simulation = model(start_time, end_time) - assert simulation is not None + simulation = model(torch.as_tensor(start_time), torch.as_tensor(end_time)) + assert isinstance(simulation, State) @pytest.mark.parametrize("url", MODEL_URLS) @@ -33,8 +35,8 @@ def test_compiled_dynamics_load_path(url, start_time, end_time): model = CompiledDynamics.load(tf.name) assert isinstance(model, CompiledDynamics) - simulation = model(start_time, end_time) - assert simulation is not None + simulation = model(torch.as_tensor(start_time), torch.as_tensor(end_time)) + assert isinstance(simulation, State) @pytest.mark.parametrize("url", MODEL_URLS) @@ -48,5 +50,5 @@ def test_compiled_dynamics_load_json(url, start_time, end_time): model = CompiledDynamics.load(model_json) assert isinstance(model, CompiledDynamics) - simulation = model(start_time, end_time) - assert simulation is not None + simulation = model(torch.as_tensor(start_time), torch.as_tensor(end_time)) + assert isinstance(simulation, State) diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py new file mode 100644 index 000000000..4ea1d6957 --- /dev/null +++ b/tests/test_interfaces.py @@ -0,0 +1,164 @@ +import pytest +import torch + +from pyciemss.compiled_dynamics import CompiledDynamics +from pyciemss.interfaces import sample + +from .fixtures import ( + END_TIMES, + LOGGING_STEP_SIZES, + MODEL_URLS, + NUM_SAMPLES, + START_TIMES, + check_result_sizes, + check_states_match_in_all_but_values, +) + + +@pytest.mark.parametrize("url", MODEL_URLS) +@pytest.mark.parametrize("start_time", START_TIMES) +@pytest.mark.parametrize("end_time", END_TIMES) +@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +@pytest.mark.parametrize("num_samples", NUM_SAMPLES) +def test_sample_no_interventions( + url, start_time, end_time, logging_step_size, num_samples +): + result = sample( + url, end_time, logging_step_size, num_samples, start_time=start_time + ) + assert isinstance(result, dict) + check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) + + +@pytest.mark.parametrize("url", MODEL_URLS) +@pytest.mark.parametrize("start_time", START_TIMES) +@pytest.mark.parametrize("end_time", END_TIMES) +@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +@pytest.mark.parametrize("num_samples", NUM_SAMPLES) +def test_sample_with_static_interventions( + url, start_time, end_time, logging_step_size, num_samples +): + model = CompiledDynamics.load(url) + + initial_state = model.initial_state() + intervened_state_1 = {k: v + 1 for k, v in initial_state.items()} + intervened_state_2 = {k: v + 2 for k, v in initial_state.items()} + + intervention_time_1 = (end_time + start_time) / 2 # Midpoint + intervention_time_2 = (end_time + intervention_time_1) / 2 # 3/4 point + static_interventions = { + intervention_time_1: intervened_state_1, + intervention_time_2: intervened_state_2, + } + + intervened_result = sample( + url, + end_time, + logging_step_size, + num_samples, + start_time=start_time, + static_interventions=static_interventions, + ) + + result = sample( + url, end_time, logging_step_size, num_samples, start_time=start_time + ) + + check_states_match_in_all_but_values(result, intervened_result) + check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) + check_result_sizes( + intervened_result, start_time, end_time, logging_step_size, num_samples + ) + + +@pytest.mark.parametrize("url", MODEL_URLS) +@pytest.mark.parametrize("start_time", START_TIMES) +@pytest.mark.parametrize("end_time", END_TIMES) +@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +@pytest.mark.parametrize("num_samples", NUM_SAMPLES) +def test_sample_with_dynamic_interventions( + url, start_time, end_time, logging_step_size, num_samples +): + model = CompiledDynamics.load(url) + + initial_state = model.initial_state() + intervened_state_1 = {k: v + 1 for k, v in initial_state.items()} + intervened_state_2 = {k: v + 2 for k, v in initial_state.items()} + + intervention_time_1 = (end_time + start_time) / 2 # Midpoint + intervention_time_2 = (end_time + intervention_time_1) / 2 # 3/4 point + + def intervention_event_fn_1(time: torch.Tensor, *args, **kwargs): + return time - intervention_time_1 + + def intervention_event_fn_2(time: torch.Tensor, *args, **kwargs): + return time - intervention_time_2 + + dynamic_interventions = { + intervention_event_fn_1: intervened_state_1, + intervention_event_fn_2: intervened_state_2, + } + + intervened_result = sample( + url, + end_time, + logging_step_size, + num_samples, + start_time=start_time, + dynamic_interventions=dynamic_interventions, + ) + + result = sample( + url, end_time, logging_step_size, num_samples, start_time=start_time + ) + + check_states_match_in_all_but_values(result, intervened_result) + check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) + check_result_sizes( + intervened_result, start_time, end_time, logging_step_size, num_samples + ) + + +@pytest.mark.parametrize("url", MODEL_URLS) +@pytest.mark.parametrize("start_time", START_TIMES) +@pytest.mark.parametrize("end_time", END_TIMES) +@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +@pytest.mark.parametrize("num_samples", NUM_SAMPLES) +def test_sample_with_static_and_dynamic_interventions( + url, start_time, end_time, logging_step_size, num_samples +): + model = CompiledDynamics.load(url) + + initial_state = model.initial_state() + intervened_state_1 = {k: v + 1 for k, v in initial_state.items()} + intervened_state_2 = {k: v + 2 for k, v in initial_state.items()} + + intervention_time_1 = (end_time + start_time) / 2 # Midpoint + intervention_time_2 = (end_time + intervention_time_1) / 2 # 3/4 point + + def intervention_event_fn_1(time: torch.Tensor, *args, **kwargs): + return time - intervention_time_1 + + dynamic_interventions = {intervention_event_fn_1: intervened_state_1} + + static_interventions = {intervention_time_2: intervened_state_2} + + intervened_result = sample( + url, + end_time, + logging_step_size, + num_samples, + start_time=start_time, + static_interventions=static_interventions, + dynamic_interventions=dynamic_interventions, + ) + + result = sample( + url, end_time, logging_step_size, num_samples, start_time=start_time + ) + + check_states_match_in_all_but_values(result, intervened_result) + check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) + check_result_sizes( + intervened_result, start_time, end_time, logging_step_size, num_samples + )