Skip to content

Commit

Permalink
progress hook progress
Browse files Browse the repository at this point in the history
  • Loading branch information
sabinala committed Jul 22, 2024
1 parent 24369d8 commit a5aec8c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 23 deletions.
42 changes: 24 additions & 18 deletions docs/source/interfaces.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pyciemss\n",
"import torch\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"from typing import Dict, List, Callable\n",
Expand Down Expand Up @@ -2513,7 +2514,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -2533,19 +2534,23 @@
"end_time = 40.0\n",
"logging_step_size = 1.0\n",
"observed_params = [\"I_state\"]\n",
"intervened_params = [\"p_cbeta\"]"
"intervened_params = [\"p_cbeta\"]\n",
"\n",
"maxiter=3\n",
"maxfeval=20"
]
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 45%|████▌ | 9/20 [00:05<00:06, 1.77it/s]"
" 46%|██████████████████████▏ | 37/80 [00:41<00:47, 1.11s/it]\n",
" 45%|█████████████████████▋ | 38/84 [00:04<00:04, 9.46it/s]"
]
},
{
Expand All @@ -2557,9 +2562,9 @@
" success: True\n",
" fun: 0.2111072303557351\n",
" x: [ 1.389e-01]\n",
" nit: 1\n",
" minimization_failures: 0\n",
" nfev: 9\n",
" nit: 3\n",
" minimization_failures: 3\n",
" nfev: 38\n",
" lowest_optimization_result: message: Optimization terminated successfully.\n",
" success: True\n",
" status: 1\n",
Expand All @@ -2568,13 +2573,6 @@
" nfev: 9\n",
" maxcv: 0.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
Expand All @@ -2592,6 +2590,13 @@
" param_value = [lambda x: torch.tensor([x])],\n",
" start_time = intervention_time,\n",
")\n",
"\n",
"\n",
"pbar = tqdm(total=(maxfeval + 1) * (maxiter + 1))\n",
"\n",
"def update_progress(xk):\n",
" pbar.update(1)\n",
"\n",
"opt_result = pyciemss.optimize(\n",
" model3,\n",
" end_time,\n",
Expand All @@ -2604,10 +2609,11 @@
" bounds_interventions=bounds_interventions,\n",
" start_time=0.0,\n",
" n_samples_ouu=int(1e2),\n",
" maxiter=0,\n",
" maxfeval=20,\n",
" maxiter=maxiter,\n",
" maxfeval=maxfeval,\n",
" solver_method=\"euler\",\n",
" solver_options={\"step_size\": logging_step_size/2},\n",
" progress_hook=update_progress,\n",
")\n",
"print(f'Optimal policy:', opt_result[\"policy\"])\n",
"print(opt_result)"
Expand Down Expand Up @@ -3339,7 +3345,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.9"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,10 @@ def __init__(
# self.kwargs = kwargs

def solve(self):
pbar = tqdm(total=self.maxfeval * (self.maxiter + 1))
# pbar = tqdm(total=self.maxfeval * (self.maxiter + 1))

def update_progress(xk):
pbar.update(1)
# def update_progress(xk):
# pbar.update(1)

# wrapper around SciPy optimizer(s)
# rhobeg is set to 10% of longest euclidean distance
Expand Down
4 changes: 2 additions & 2 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,8 +581,8 @@ def __call__(self, iteration):
"solver_options": {"step_size": 0.1},
"start_time": start_time,
"n_samples_ouu": int(2),
"maxiter": 1,
"maxfeval": 2,
"maxiter": 5,
"maxfeval": 3,
"progress_hook": progress_hook,
}
bounds_interventions = optimize_kwargs["bounds_interventions"]
Expand Down

0 comments on commit a5aec8c

Please sign in to comment.