Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

477 add progress hook to optimize #592

Merged
merged 15 commits into from
Aug 27, 2024
42 changes: 22 additions & 20 deletions docs/source/interfaces.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"import os\n",
"import pyciemss\n",
"import torch\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"from typing import Dict, List, Callable\n",
Expand Down Expand Up @@ -2523,7 +2524,10 @@
"end_time = 40.0\n",
"logging_step_size = 1.0\n",
"observed_params = [\"I_state\"]\n",
"intervened_params = [\"p_cbeta\"]"
"intervened_params = [\"p_cbeta\"]\n",
"\n",
"maxiter=5\n",
"maxfeval=10"
]
},
{
Expand All @@ -2535,11 +2539,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Anirban\\Documents\\GitHub\\pyciemss\\pyciemss\\interfaces.py:870: UserWarning: risk_bound is not a List. Forcing it to be a list.\n",
" warnings.warn(\"risk_bound is not a List. Forcing it to be a list.\")\n",
"C:\\Users\\Anirban\\Documents\\GitHub\\pyciemss\\pyciemss\\interfaces.py:873: UserWarning: qoi is not a List. Forcing it to be a list.\n",
" warnings.warn(\"qoi is not a List. Forcing it to be a list.\")\n",
" 40%|████ | 8/20 [00:04<00:07, 1.62it/s]"
" 75%|████████████████████████████████████ | 45/60 [00:06<00:01, 9.23it/s]"
]
},
{
Expand All @@ -2551,9 +2551,9 @@
" success: True\n",
" fun: 0.14687910929135808\n",
" x: [ 2.032e-01]\n",
" nit: 1\n",
" minimization_failures: 0\n",
" nfev: 8\n",
" nit: 5\n",
" minimization_failures: 4\n",
" nfev: 46\n",
" lowest_optimization_result: message: Optimization terminated successfully.\n",
" success: True\n",
" status: 1\n",
Expand All @@ -2562,13 +2562,6 @@
" nfev: 8\n",
" maxcv: 0.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
Expand All @@ -2586,6 +2579,14 @@
" param_value=[lambda x: torch.tensor(x)],\n",
" start_time=intervention_time,\n",
")\n",
"\n",
"# Progress bar\n",
"pbar = tqdm(total=maxfeval * (maxiter + 1))\n",
"\n",
"# update_progress(coordinate, function_min, accept)\n",
"def update_progress(xk):\n",
" pbar.update(1)\n",
"\n",
"opt_result = pyciemss.optimize(\n",
" model3,\n",
" end_time,\n",
Expand All @@ -2598,10 +2599,11 @@
" bounds_interventions=bounds_interventions,\n",
" start_time=0.0,\n",
" n_samples_ouu=int(1e2),\n",
" maxiter=0,\n",
" maxfeval=20,\n",
" maxiter=maxiter,\n",
" maxfeval=maxfeval,\n",
" solver_method=\"euler\",\n",
" solver_options={\"step_size\": logging_step_size / 2},\n",
" solver_options={\"step_size\": logging_step_size/2},\n",
" progress_hook=update_progress,\n",
")\n",
"print(f\"Optimal policy:\", opt_result[\"policy\"])\n",
"print(opt_result)"
Expand Down Expand Up @@ -3339,7 +3341,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.9"
}
},
"nbformat": 4,
Expand Down
6 changes: 6 additions & 0 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,7 @@ def optimize(
maxfeval: int = 25,
verbose: bool = False,
roundup_decimal: int = 4,
progress_hook: Callable[[torch.Tensor], None] = lambda x: None,
) -> Dict[str, Any]:
r"""
Load a model from a file, compile it into a probabilistic program, and optimize under uncertainty with risk-based
Expand Down Expand Up @@ -858,6 +859,10 @@ def optimize(
- Whether to print out the optimization under uncertainty progress.
roundup_decimal: int
- Number of significant digits for the optimal policy.
progress_hook: progress_hook: Callable[[torch.Tensor], None],
- A callback function that takes in the current parameter vector as a tensor.
If the function returns StopIteration, the minimization will terminate.
- This can be used to implement custom progress bars and/or early stopping criteria.

Returns:
result: Dict[str, Any]
Expand Down Expand Up @@ -946,6 +951,7 @@ def objfun_penalty(x):
maxiter=maxiter,
maxfeval=maxfeval,
u_bounds=bounds_np,
progress_hook=progress_hook,
).solve()

# Rounding up to given number of decimal places
Expand Down
13 changes: 6 additions & 7 deletions pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.interventional.ops import Intervention
from scipy.optimize import basinhopping
from tqdm import tqdm

from pyciemss.integration_utils.intervention_builder import (
combine_static_parameter_interventions,
Expand Down Expand Up @@ -181,34 +180,34 @@ def __init__(
maxfeval: int = 100,
maxiter: int = 100,
u_bounds: np.ndarray = np.atleast_2d([[0], [1]]),
progress_hook: Callable[
[torch.Tensor], None
] = lambda x: None, # update_progress
):
self.x0 = np.squeeze(np.array([x0]))
self.objfun = objfun
self.constraints = constraints
self.maxiter = maxiter
self.maxfeval = maxfeval
self.u_bounds = u_bounds
self.progress_hook = progress_hook
# self.kwargs = kwargs

def solve(self):
pbar = tqdm(total=self.maxfeval * (self.maxiter + 1))

def update_progress(xk):
pbar.update(1)

# wrapper around SciPy optimizer(s)
# rhobeg is set to 10% of longest euclidean distance
minimizer_kwargs = dict(
constraints=self.constraints,
method="COBYLA",
tol=1e-5,
callback=update_progress,
options={
"rhobeg": 0.1
* np.linalg.norm(self.u_bounds[1, :] - self.u_bounds[0, :]),
"disp": False,
"maxiter": self.maxfeval,
"catol": 1e-5,
},
callback=self.progress_hook,
)
take_step = RandomDisplacementBounds(self.u_bounds[0, :], self.u_bounds[1, :])

Expand Down
19 changes: 18 additions & 1 deletion tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,14 +563,27 @@ def test_output_format(
def test_optimize(model_fixture, start_time, end_time, num_samples):
logging_step_size = 1.0
model_url = model_fixture.url

class TestProgressHook:
def __init__(self):
self.result_x = []

def __call__(self, x):
# Log the iteration number
self.result_x.append(x)
print(f"Result: {self.result_x}")

progress_hook = TestProgressHook()

optimize_kwargs = {
**model_fixture.optimize_kwargs,
"solver_method": "euler",
"solver_options": {"step_size": 0.1},
"start_time": start_time,
"n_samples_ouu": int(5),
"n_samples_ouu": int(2),
"maxiter": 1,
"maxfeval": 2,
"progress_hook": progress_hook,
}
bounds_interventions = optimize_kwargs["bounds_interventions"]
opt_result = optimize(
Expand Down Expand Up @@ -622,6 +635,10 @@ def test_optimize(model_fixture, start_time, end_time, num_samples):
intervened_result_subset, start_time, end_time, logging_step_size, num_samples
)

assert len(progress_hook.result_x) <= (
(optimize_kwargs["maxfeval"] + 1) * (optimize_kwargs["maxiter"] + 1)
)


@pytest.mark.parametrize("model_fixture", MODELS)
@pytest.mark.parametrize("bad_num_iterations", NON_POS_INTS)
Expand Down
Loading