Skip to content

Commit

Permalink
477 add progress hook to optimize (#592)
Browse files Browse the repository at this point in the history
* progress hook progress

* progress hook progress

* Lint

* moving callback for progress_hook

* linting

* fixing typing on array

* fixing typing on array

* linting

* linting

* moving progress hook to inside minimizer

* deleting commented code and changing typing

* fixing typing

* adding changes

* rerunning interfaces notebook

---------

Co-authored-by: Anirban Chaudhuri <[email protected]>
  • Loading branch information
sabinala and anirban-chaudhuri authored Aug 27, 2024
1 parent 059cbfb commit 83c974a
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 28 deletions.
42 changes: 22 additions & 20 deletions docs/source/interfaces.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"import os\n",
"import pyciemss\n",
"import torch\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"from typing import Dict, List, Callable\n",
Expand Down Expand Up @@ -2523,7 +2524,10 @@
"end_time = 40.0\n",
"logging_step_size = 1.0\n",
"observed_params = [\"I_state\"]\n",
"intervened_params = [\"p_cbeta\"]"
"intervened_params = [\"p_cbeta\"]\n",
"\n",
"maxiter=5\n",
"maxfeval=10"
]
},
{
Expand All @@ -2535,11 +2539,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Anirban\\Documents\\GitHub\\pyciemss\\pyciemss\\interfaces.py:870: UserWarning: risk_bound is not a List. Forcing it to be a list.\n",
" warnings.warn(\"risk_bound is not a List. Forcing it to be a list.\")\n",
"C:\\Users\\Anirban\\Documents\\GitHub\\pyciemss\\pyciemss\\interfaces.py:873: UserWarning: qoi is not a List. Forcing it to be a list.\n",
" warnings.warn(\"qoi is not a List. Forcing it to be a list.\")\n",
" 40%|████ | 8/20 [00:04<00:07, 1.62it/s]"
" 75%|████████████████████████████████████ | 45/60 [00:06<00:01, 9.23it/s]"
]
},
{
Expand All @@ -2551,9 +2551,9 @@
" success: True\n",
" fun: 0.14687910929135808\n",
" x: [ 2.032e-01]\n",
" nit: 1\n",
" minimization_failures: 0\n",
" nfev: 8\n",
" nit: 5\n",
" minimization_failures: 4\n",
" nfev: 46\n",
" lowest_optimization_result: message: Optimization terminated successfully.\n",
" success: True\n",
" status: 1\n",
Expand All @@ -2562,13 +2562,6 @@
" nfev: 8\n",
" maxcv: 0.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
Expand All @@ -2586,6 +2579,14 @@
" param_value=[lambda x: torch.tensor(x)],\n",
" start_time=intervention_time,\n",
")\n",
"\n",
"# Progress bar\n",
"pbar = tqdm(total=maxfeval * (maxiter + 1))\n",
"\n",
"# update_progress(coordinate, function_min, accept)\n",
"def update_progress(xk):\n",
" pbar.update(1)\n",
"\n",
"opt_result = pyciemss.optimize(\n",
" model3,\n",
" end_time,\n",
Expand All @@ -2598,10 +2599,11 @@
" bounds_interventions=bounds_interventions,\n",
" start_time=0.0,\n",
" n_samples_ouu=int(1e2),\n",
" maxiter=0,\n",
" maxfeval=20,\n",
" maxiter=maxiter,\n",
" maxfeval=maxfeval,\n",
" solver_method=\"euler\",\n",
" solver_options={\"step_size\": logging_step_size / 2},\n",
" solver_options={\"step_size\": logging_step_size/2},\n",
" progress_hook=update_progress,\n",
")\n",
"print(f\"Optimal policy:\", opt_result[\"policy\"])\n",
"print(opt_result)"
Expand Down Expand Up @@ -3339,7 +3341,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.9"
}
},
"nbformat": 4,
Expand Down
6 changes: 6 additions & 0 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,7 @@ def optimize(
maxfeval: int = 25,
verbose: bool = False,
roundup_decimal: int = 4,
progress_hook: Callable[[torch.Tensor], None] = lambda x: None,
) -> Dict[str, Any]:
r"""
Load a model from a file, compile it into a probabilistic program, and optimize under uncertainty with risk-based
Expand Down Expand Up @@ -858,6 +859,10 @@ def optimize(
- Whether to print out the optimization under uncertainty progress.
roundup_decimal: int
- Number of significant digits for the optimal policy.
progress_hook: progress_hook: Callable[[torch.Tensor], None],
- A callback function that takes in the current parameter vector as a tensor.
If the function returns StopIteration, the minimization will terminate.
- This can be used to implement custom progress bars and/or early stopping criteria.
Returns:
result: Dict[str, Any]
Expand Down Expand Up @@ -946,6 +951,7 @@ def objfun_penalty(x):
maxiter=maxiter,
maxfeval=maxfeval,
u_bounds=bounds_np,
progress_hook=progress_hook,
).solve()

# Rounding up to given number of decimal places
Expand Down
13 changes: 6 additions & 7 deletions pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.interventional.ops import Intervention
from scipy.optimize import basinhopping
from tqdm import tqdm

from pyciemss.integration_utils.intervention_builder import (
combine_static_parameter_interventions,
Expand Down Expand Up @@ -181,34 +180,34 @@ def __init__(
maxfeval: int = 100,
maxiter: int = 100,
u_bounds: np.ndarray = np.atleast_2d([[0], [1]]),
progress_hook: Callable[
[torch.Tensor], None
] = lambda x: None, # update_progress
):
self.x0 = np.squeeze(np.array([x0]))
self.objfun = objfun
self.constraints = constraints
self.maxiter = maxiter
self.maxfeval = maxfeval
self.u_bounds = u_bounds
self.progress_hook = progress_hook
# self.kwargs = kwargs

def solve(self):
pbar = tqdm(total=self.maxfeval * (self.maxiter + 1))

def update_progress(xk):
pbar.update(1)

# wrapper around SciPy optimizer(s)
# rhobeg is set to 10% of longest euclidean distance
minimizer_kwargs = dict(
constraints=self.constraints,
method="COBYLA",
tol=1e-5,
callback=update_progress,
options={
"rhobeg": 0.1
* np.linalg.norm(self.u_bounds[1, :] - self.u_bounds[0, :]),
"disp": False,
"maxiter": self.maxfeval,
"catol": 1e-5,
},
callback=self.progress_hook,
)
take_step = RandomDisplacementBounds(self.u_bounds[0, :], self.u_bounds[1, :])

Expand Down
19 changes: 18 additions & 1 deletion tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,14 +563,27 @@ def test_output_format(
def test_optimize(model_fixture, start_time, end_time, num_samples):
logging_step_size = 1.0
model_url = model_fixture.url

class TestProgressHook:
def __init__(self):
self.result_x = []

def __call__(self, x):
# Log the iteration number
self.result_x.append(x)
print(f"Result: {self.result_x}")

progress_hook = TestProgressHook()

optimize_kwargs = {
**model_fixture.optimize_kwargs,
"solver_method": "euler",
"solver_options": {"step_size": 0.1},
"start_time": start_time,
"n_samples_ouu": int(5),
"n_samples_ouu": int(2),
"maxiter": 1,
"maxfeval": 2,
"progress_hook": progress_hook,
}
bounds_interventions = optimize_kwargs["bounds_interventions"]
opt_result = optimize(
Expand Down Expand Up @@ -622,6 +635,10 @@ def test_optimize(model_fixture, start_time, end_time, num_samples):
intervened_result_subset, start_time, end_time, logging_step_size, num_samples
)

assert len(progress_hook.result_x) <= (
(optimize_kwargs["maxfeval"] + 1) * (optimize_kwargs["maxiter"] + 1)
)


@pytest.mark.parametrize("model_fixture", MODELS)
@pytest.mark.parametrize("bad_num_iterations", NON_POS_INTS)
Expand Down

0 comments on commit 83c974a

Please sign in to comment.