Ensemble calibrate (#522)
* first pass at calibrate ensemble

* lint

* add docstring
SamWitty authored Mar 13, 2024
1 parent 8acd2df commit f5b1360
Showing 2 changed files with 190 additions and 3 deletions.
168 changes: 168 additions & 0 deletions pyciemss/interfaces.py
@@ -14,6 +14,7 @@
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.interventional.ops import Intervention
from chirho.observational.handlers import condition
from chirho.observational.ops import observe

from pyciemss.compiled_dynamics import CompiledDynamics
from pyciemss.ensemble.compiled_dynamics import EnsembleCompiledDynamics
@@ -148,6 +149,173 @@ def wrapped_model():
)


@pyciemss_logging_wrapper
def ensemble_calibrate(
    model_paths_or_jsons: List[Union[str, Dict]],
    solution_mappings: List[
        Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]
    ],
    data_path: str,
    *,
    dirichlet_alpha: Optional[torch.Tensor] = None,
    data_mapping: Dict[str, str] = {},
    noise_model: str = "normal",
    noise_model_kwargs: Dict[str, Any] = {"scale": 0.1},
    solver_method: str = "dopri5",
    solver_options: Dict[str, Any] = {},
    start_time: float = 0.0,
    num_iterations: int = 1000,
    lr: float = 0.03,
    verbose: bool = False,
    num_particles: int = 1,
    deterministic_learnable_parameters: List[str] = [],
    progress_hook: Callable = lambda i, loss: None,
) -> Dict[str, Any]:
"""
Infer parameters for an ensemble of DynamicalSystem models conditional on data.
This uses variational inference with a mean-field variational family to infer the parameters of the model.
Args:
model_paths_or_jsons: List[Union[str, Dict]]
- A list of paths to AMR model files or JSONs containing models in AMR form.
solution_mappings: List[Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]]
- A list of functions that map the solution of each model to a common solution space.
- Each function takes in a dictionary of the form {state_variable_name: value}
and returns a dictionary of the same form.
data_path: str
- A path to the data file.
dirichlet_alpha: Optional[torch.Tensor]
- A tensor of shape (num_models,) containing the Dirichlet alpha values for the ensemble.
- A higher proportion of alpha values will result in higher weights for the corresponding models.
- A larger total alpha values will result in more certain priors.
- e.g. torch.tensor([1, 1, 1]) will result in a uniform prior over vectors of length 3 that sum to 1.
- e.g. torch.tensor([1, 2, 3]) will result in a prior that is biased towards the third model.
- If not provided, we will use a uniform Dirichlet prior.
data_mapping: Dict[str, str]
- A mapping from column names in the data file to state variable names in the model.
- keys: str name of column in dataset
- values: str name of state/observable in model
- If not provided, we will assume that the column names in the data file match the state variable names.
- Note: This mapping must match output of `solution_mappings`.
noise_model: str
- The noise model to use for the data.
- Currently we only support the normal distribution.
noise_model_kwargs: Dict[str, Any]
- Keyword arguments to pass to the noise model.
- Currently we only support the `scale` keyword argument for the normal distribution.
solver_method: str
- The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
- By default we set the `start_time` to be 0.
num_iterations: int
- The number of iterations to run the inference algorithm for.
lr: float
- The learning rate to use for the inference algorithm.
verbose: bool
- Whether to print out the loss at each iteration.
num_particles: int
- The number of particles to use for the inference algorithm.
deterministic_learnable_parameters: List[str]
- A list of parameter names that should be learned deterministically.
- By default, all parameters are learned probabilistically.
progress_hook: Callable[[int, float], None]
- A function that takes in the current iteration and the current loss.
- This is called at the beginning of each iteration.
- By default, this is a no-op.
- This can be used to implement custom progress bars.
Returns:
result: Dict[str, Any]
- Dictionary with the following key-value pairs.
- inferred_parameters: pyro.nn.PyroModule
- A Pyro module that contains the inferred parameters of the model.
- This can be passed to `ensemble_sample` to sample from the model conditional on the data.
- loss: float
- The final loss value of the approximate ELBO loss.
"""

    pyro.clear_param_store()

    if dirichlet_alpha is None:
        dirichlet_alpha = torch.ones(len(model_paths_or_jsons))

    model = EnsembleCompiledDynamics.load(
        model_paths_or_jsons, dirichlet_alpha, solution_mappings
    )

    data_timepoints, data = load_data(data_path, data_mapping=data_mapping)

    # Check that num_iterations is a positive integer
    if not (isinstance(num_iterations, int) and num_iterations > 0):
        raise ValueError("num_iterations must be a positive integer")

    def autoguide(model):
        # Combine a point-estimate (Delta) guide for parameters that should be learned
        # deterministically with a low-rank multivariate normal guide for all remaining
        # latent variables.
        guide = pyro.infer.autoguide.AutoGuideList(model)
        guide.append(
            pyro.infer.autoguide.AutoDelta(
                pyro.poutine.block(model, expose=deterministic_learnable_parameters)
            )
        )

        try:
            mvn_guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(
                pyro.poutine.block(model, hide=deterministic_learnable_parameters)
            )
            mvn_guide._setup_prototype()
            guide.append(mvn_guide)
        except RuntimeError as re:
            # If every latent variable was marked deterministic, Pyro raises this specific
            # RuntimeError; in that case the AutoDelta guide alone is sufficient.
            if (
                re.args[0]
                != "AutoLowRankMultivariateNormal found no latent variables; Use an empty guide instead"
            ):
                raise re

        return guide

    _noise_model = compile_noise_model(
        noise_model,
        vars=set(data.keys()),
        **noise_model_kwargs,
    )

    # The compiled noise model creates observation sites named "<variable>_noisy",
    # so rename the observed data to condition on those sites.
    _data = {f"{k}_noisy": v for k, v in data.items()}

    def wrapped_model():
        obs = condition(data=_data)(_noise_model)

        with TorchDiffEq(method=solver_method, options=solver_options):
            solution = model(
                torch.as_tensor(start_time),
                torch.as_tensor(data_timepoints[-1]),
                logging_times=data_timepoints,
                is_traced=True,
            )

        observe(solution, obs)

    inferred_parameters = autoguide(wrapped_model)

    optim = pyro.optim.Adam({"lr": lr})
    elbo = pyro.infer.Trace_ELBO(num_particles=num_particles)
    svi = pyro.infer.SVI(wrapped_model, inferred_parameters, optim, loss=elbo)

    loss = 0.0
    for i in range(num_iterations):
        # Call a progress hook at the beginning of each iteration. This is used to implement custom progress bars.
        progress_hook(i, loss)
        loss = svi.step()
        if verbose:
            if i % 25 == 0:
                print(f"iteration {i}: loss = {loss}")

    return {"inferred_parameters": inferred_parameters, "loss": loss}


@pyciemss_logging_wrapper
def sample(
    model_path_or_json: Union[str, Dict],
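
For orientation, here is a minimal usage sketch of the new ensemble_calibrate interface. It is illustrative rather than part of the diff: the model files (sir.json, seir.json), the dataset cases.csv with a Cases column, the state names used in the solution mappings, and the tqdm-based progress hook are all hypothetical stand-ins; only the ensemble_calibrate signature and keyword arguments come from the code above.

import torch
from tqdm import tqdm

from pyciemss.interfaces import ensemble_calibrate

model_paths = ["sir.json", "seir.json"]  # hypothetical AMR model files

# Map each model's solution onto a shared "Infected" observable.
solution_mappings = [
    lambda sol: {"Infected": sol["I"]},
    lambda sol: {"Infected": sol["I_sym"] + sol["I_asym"]},
]

progress = tqdm(total=500, desc="calibrating")

def progress_hook(i, loss):
    # Called at the start of every SVI iteration; implements a simple progress bar.
    progress.update(1)

result = ensemble_calibrate(
    model_paths,
    solution_mappings,
    "cases.csv",                               # hypothetical dataset
    data_mapping={"Cases": "Infected"},        # dataset column -> mapped observable
    dirichlet_alpha=torch.tensor([1.0, 2.0]),  # prior weight biased toward the second model
    num_iterations=500,
    progress_hook=progress_hook,
)

print("final ELBO loss:", result["loss"])

As documented above, the returned result["inferred_parameters"] guide can then be passed to ensemble_sample to simulate from the calibrated ensemble.
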
25 changes: 22 additions & 3 deletions tests/test_interfaces.py
@@ -6,7 +6,13 @@

from pyciemss.compiled_dynamics import CompiledDynamics
from pyciemss.integration_utils.observation import load_data
from pyciemss.interfaces import calibrate, ensemble_sample, optimize, sample
from pyciemss.interfaces import (
    calibrate,
    ensemble_calibrate,
    ensemble_sample,
    optimize,
    sample,
)

from .fixtures import (
    BADLY_FORMATTED_DATAFRAMES,
@@ -34,6 +40,15 @@ def dummy_ensemble_sample(model_path_or_json, *args, **kwargs):
    return ensemble_sample(model_paths_or_jsons, solution_mappings, *args, **kwargs)


def dummy_ensemble_calibrate(model_path_or_json, *args, **kwargs):
    # Build a two-model ensemble from the same model: the first solution mapping is the
    # identity and the second halves every state value.
    model_paths_or_jsons = [model_path_or_json, model_path_or_json]
    solution_mappings = [
        lambda x: x,
        lambda x: {k: v / 2 for k, v in x.items()},
    ]
    return ensemble_calibrate(model_paths_or_jsons, solution_mappings, *args, **kwargs)


def setup_calibrate(model_fixture, start_time, end_time, logging_step_size):
    if model_fixture.data_path is None:
        pytest.skip("TODO: create temporary file")
@@ -55,6 +70,7 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size):


SAMPLE_METHODS = [sample, dummy_ensemble_sample]
CALIBRATE_METHODS = [calibrate, dummy_ensemble_calibrate]
INTERVENTION_TYPES = ["static", "dynamic"]
INTERVENTION_TARGETS = ["state", "parameter"]

@@ -291,11 +307,14 @@ def test_sample_with_multiple_parameter_interventions(
)


@pytest.mark.parametrize("calibrate_method", CALIBRATE_METHODS)
@pytest.mark.parametrize("model_fixture", MODELS)
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
def test_calibrate_no_kwargs(model_fixture, start_time, end_time, logging_step_size):
def test_calibrate_no_kwargs(
    calibrate_method, model_fixture, start_time, end_time, logging_step_size
):
    model_url = model_fixture.url
    _, _, sample_args, sample_kwargs = setup_calibrate(
        model_fixture, start_time, end_time, logging_step_size
@@ -310,7 +329,7 @@ def test_calibrate_no_kwargs(model_fixture, start_time, end_time, logging_step_size):
    }

    with pyro.poutine.seed(rng_seed=0):
        inferred_parameters = calibrate(*calibrate_args, **calibrate_kwargs)[
        inferred_parameters = calibrate_method(*calibrate_args, **calibrate_kwargs)[
            "inferred_parameters"
        ]

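
A short aside on the dirichlet_alpha argument documented above: the sketch below (plain torch, no pyciemss required) shows how the concentration vector shapes the prior over ensemble weights. The specific alpha values are illustrative.

import torch

# alpha = [1, 1, 1]: uniform prior over three model weights that sum to 1.
uniform_prior = torch.distributions.Dirichlet(torch.tensor([1.0, 1.0, 1.0]))

# alpha = [1, 2, 3]: biased toward the third model, and more concentrated
# than the uniform prior because the total alpha is larger.
biased_prior = torch.distributions.Dirichlet(torch.tensor([1.0, 2.0, 3.0]))

print(uniform_prior.mean)         # tensor([0.3333, 0.3333, 0.3333])
print(biased_prior.mean)          # tensor([0.1667, 0.3333, 0.5000])
print(biased_prior.sample((3,)))  # rows are weight vectors on the simplex
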
