Skip to content

Commit

Permalink
Update optimize to be more computationally efficient and add separa…
Browse files Browse the repository at this point in the history
…te notebooks for `optimize` and `sample` (#529)

* Create optimize_interface.ipynb

* Update optimize_interface.ipynb

* Optimize interface with multiple interventions

* Update optimize_interface.ipynb

* added SEIRHD model to notebook

* Update optimize_interface.ipynb

* Updating intervention builder to handle multiple interventions at the same time

* Update optimize_interface.ipynb

* Update interfaces.py

* Update optimize_interface.ipynb

* Notebook testing different solver options to provide more details to users

* Notebook updates

* Linting

* Update optimize_interface.ipynb

* Update optimize_interface.ipynb

* Updated intervention_builder to handle multiple intervention param values

* type annotation for empty dictionary

* Adding smoke_test to optimize interface parameters

* Update optimize_interface.ipynb

* Added seed

* Updated to use rk4 in optimization

* Adding penalties to force the optimizer to find a solution within the bounds

Avoids simulating the model if interventions are out of bounds

* Linting

* Update optimize_interface.ipynb

* Update interfaces.ipynb

* Adding `smoke_test` to `optimize` interface parameters

* copying interfaces notebook from main to avoid merge conflicts

* Updated ouu to use parallel when sampling

Significantly improves speed

* Linting

* Update `RandomDisplacementBounds`

* Lint

* remove extra print statements

* Update optimize_interface.ipynb

* Removing `torch.squeeze()` from `sample`

* Update sample_interface_solvers.ipynb
  • Loading branch information
anirban-chaudhuri authored Mar 18, 2024
1 parent 3703778 commit 1fc62b0
Show file tree
Hide file tree
Showing 7 changed files with 1,671 additions and 59 deletions.
989 changes: 989 additions & 0 deletions docs/source/optimize_interface.ipynb

Large diffs are not rendered by default.

562 changes: 562 additions & 0 deletions docs/source/sample_interface_solvers.ipynb

Large diffs are not rendered by default.

42 changes: 28 additions & 14 deletions pyciemss/integration_utils/intervention_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,30 @@ def param_value_objective(
start_time: List[torch.Tensor],
param_value: List[Intervention] = [None],
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
if len(param_value) < len(param_name) and param_value[0] is None:
param_value = [None for _ in param_name]
for count in range(len(param_name)):
if param_value[count] is None:
if not callable(param_value[count]):
param_value[count] = lambda y: torch.tensor(y)

def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
static_parameter_interventions = {}
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(len(param_name)):
if param_value[count] is None:
if not callable(param_value[count]):
param_value[count] = lambda x: torch.tensor(x)
static_parameter_interventions.update(
{
start_time[count].item(): {
param_name[count]: param_value[count](x[count].item())
if start_time[count].item() in static_parameter_interventions:
static_parameter_interventions[start_time[count].item()].update(
{param_name[count]: param_value[count](x[count].item())}
)
else:
static_parameter_interventions.update(
{
start_time[count].item(): {
param_name[count]: param_value[count](x[count].item())
}
}
}
)
)
return static_parameter_interventions

return intervention_generator
Expand All @@ -35,11 +44,16 @@ def start_time_objective(
def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
static_parameter_interventions = {}
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(len(param_name)):
static_parameter_interventions.update(
{x[count].item(): {param_name[count]: param_value[count]}}
)
if x[count].item() in static_parameter_interventions:
static_parameter_interventions[x[count].item()].update(
{param_name[count]: param_value[count]}
)
else:
static_parameter_interventions.update(
{x[count].item(): {param_name[count]: param_value[count]}}
)
return static_parameter_interventions

return intervention_generator
23 changes: 15 additions & 8 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def wrapped_model():
for k, vals in samples.items():
if "_state" in k:
# qoi is assumed to be the last day of simulation
qoi_sample = vals[:, -1]
qoi_sample = vals.detach().numpy()[:, -1]
sq_est = alpha_superquantile(qoi_sample, alpha=alpha)
risk_results.update({k: {"risk": [sq_est], "qoi": qoi_sample}})

Expand Down Expand Up @@ -762,7 +762,7 @@ def optimize(
solver_options: Dict[str, Any] = {},
start_time: float = 0.0,
inferred_parameters: Optional[pyro.nn.PyroModule] = None,
n_samples_ouu: int = int(1e2),
n_samples_ouu: int = int(1e3),
maxiter: int = 5,
maxfeval: int = 25,
verbose: bool = False,
Expand Down Expand Up @@ -852,6 +852,8 @@ def optimize(
guide=inferred_parameters,
solver_method=solver_method,
solver_options=solver_options,
u_bounds=bounds_np,
risk_bound=risk_bound,
)

# Run one sample to estimate model evaluation time
Expand All @@ -877,6 +879,8 @@ def optimize(
guide=inferred_parameters,
solver_method=solver_method,
solver_options=solver_options,
u_bounds=bounds_np,
risk_bound=risk_bound,
)
# Define constraints >= 0
constraints = (
Expand All @@ -893,10 +897,18 @@ def optimize(
print(
f"Estimated wait time {time_per_eval*n_samples_ouu*(maxiter+1)*maxfeval:.1f} seconds..."
)

# Updating the objective function to penalize out of bounds interventions
def objfun_penalty(x):
if np.any(x - u_min < 0) or np.any(u_max - x < 0):
return objfun(x) + max(5 * np.abs(objfun(x)), 5.0)
else:
return objfun(x)

start_time = time.time()
opt_results = solveOUU(
x0=initial_guess_interventions,
objfun=objfun,
objfun=objfun_penalty,
constraints=constraints,
maxiter=maxiter,
maxfeval=maxfeval,
Expand All @@ -917,11 +929,6 @@ def round_up(num, dec=roundup_decimal):
)
print(f"Optimal policy:\t{opt_results.x}")

# Check for some interventions that lead to no feasible solutions
if opt_results.x < 0:
if verbose:
print("No solution found")

ouu_results = {
"policy": torch.tensor(opt_results.x),
"OptResults": opt_results,
Expand Down
1 change: 0 additions & 1 deletion pyciemss/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def markov_kernel(
raise NotImplementedError

def forward(self, state: Dict[str, torch.Tensor]) -> None:

if self.observables is not None:
for k, v in self.observables(state).items():
state[k] = v
Expand Down
112 changes: 77 additions & 35 deletions pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import warnings
from typing import Any, Callable, Dict, List, Tuple

import numpy as np
Expand All @@ -9,7 +10,10 @@
from scipy.optimize import basinhopping
from tqdm import tqdm

from pyciemss.interruptions import StaticParameterIntervention
from pyciemss.interruptions import (
ParameterInterventionTracer,
StaticParameterIntervention,
)
from pyciemss.ouu.risk_measures import alpha_superquantile


Expand All @@ -28,11 +32,24 @@ def __init__(self, xmin, xmax, stepsize=None):
self.stepsize = 0.3 * np.linalg.norm(xmax - xmin)

def __call__(self, x):
xnew = np.clip(
x + np.random.uniform(-self.stepsize, self.stepsize, np.shape(x)),
self.xmin,
self.xmax,
)
if np.any(x - self.xmax >= 0.0):
xnew = np.clip(
x + np.random.uniform(-self.stepsize, 0, np.shape(x)),
self.xmin,
self.xmax,
)
elif np.any(x - self.xmin <= 0.0):
xnew = np.clip(
x + np.random.uniform(0, self.stepsize, np.shape(x)),
self.xmin,
self.xmax,
)
else:
xnew = np.clip(
x + np.random.uniform(-self.stepsize, self.stepsize, np.shape(x)),
self.xmin,
self.xmax,
)
return xnew


Expand All @@ -55,6 +72,8 @@ def __init__(
guide=None,
solver_method: str = "dopri5",
solver_options: Dict[str, Any] = {},
u_bounds: np.ndarray = np.atleast_2d([[0], [1]]),
risk_bound: float = 0.0,
):
self.model = model
self.interventions = interventions
Expand All @@ -70,43 +89,66 @@ def __init__(
self.logging_times = torch.arange(
start_time + logging_step_size, end_time, logging_step_size
)
self.u_bounds = u_bounds
self.risk_bound = risk_bound # used for defining penalty
warnings.simplefilter("always", UserWarning)

def __call__(self, x):
# Apply intervention and perform forward uncertainty propagation
samples = self.propagate_uncertainty(x)
# Compute quanity of interest
sample_qoi = self.qoi(samples)
# Estimate risk
return self.risk_measure(sample_qoi)
if np.any(x - self.u_bounds[0, :] < 0.0) or np.any(
self.u_bounds[1, :] - x < 0.0
):
warnings.warn(
"Selected interventions are out of bounds. Will use a penalty instead of estimating risk."
)
risk_estimate = max(
2 * self.risk_bound, 10.0
) # used as a penalty and the model is not run
else:
# Apply intervention and perform forward uncertainty propagation
samples = self.propagate_uncertainty(x)
# Compute quanity of interest
sample_qoi = self.qoi(samples)
# Estimate risk
risk_estimate = self.risk_measure(sample_qoi)
return risk_estimate

def propagate_uncertainty(self, x):
"""
Perform forward uncertainty propagation.
"""
pyro.set_rng_seed(0)
x = np.atleast_1d(x)
static_parameter_interventions = self.interventions(torch.from_numpy(x))
static_parameter_intervention_handlers = [
StaticParameterIntervention(time, dict(**static_intervention_assignment))
for time, static_intervention_assignment in static_parameter_interventions.items()
]

def wrapped_model():
with TorchDiffEq(method=self.solver_method, options=self.solver_options):
with contextlib.ExitStack() as stack:
for handler in static_parameter_intervention_handlers:
stack.enter_context(handler)
self.model(
torch.as_tensor(self.start_time),
torch.as_tensor(self.end_time),
logging_times=self.logging_times,
is_traced=True,
with pyro.poutine.seed(rng_seed=0):
with torch.no_grad():
x = np.atleast_1d(x)
static_parameter_interventions = self.interventions(torch.from_numpy(x))
static_parameter_intervention_handlers = [
StaticParameterIntervention(
time, dict(**static_intervention_assignment)
)

# Sample from intervened model
samples = pyro.infer.Predictive(
wrapped_model, guide=self.guide, num_samples=self.num_samples
)()
for time, static_intervention_assignment in static_parameter_interventions.items()
]

def wrapped_model():
with ParameterInterventionTracer():
with TorchDiffEq(
method=self.solver_method, options=self.solver_options
):
with contextlib.ExitStack() as stack:
for handler in static_parameter_intervention_handlers:
stack.enter_context(handler)
self.model(
torch.as_tensor(self.start_time),
torch.as_tensor(self.end_time),
logging_times=self.logging_times,
is_traced=True,
)

# Sample from intervened model
samples = pyro.infer.Predictive(
wrapped_model,
guide=self.guide,
num_samples=self.num_samples,
parallel=True,
)()
return samples


Expand Down
1 change: 0 additions & 1 deletion tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ def test_sample_with_multiple_parameter_interventions(
logging_step_size,
num_samples,
):

model_url = model_fixture.url
model = CompiledDynamics.load(model_url)

Expand Down

0 comments on commit 1fc62b0

Please sign in to comment.