Skip to content

Commit

Permalink
progress hook progress
Browse files Browse the repository at this point in the history
  • Loading branch information
sabinala committed Jul 22, 2024
1 parent 672201f commit 24369d8
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ def optimize(
maxfeval: int = 25,
verbose: bool = False,
roundup_decimal: int = 4,
progress_hook: Callable = lambda i, feval: None,
) -> Dict[str, Any]:
r"""
Load a model from a file, compile it into a probabilistic program, and optimize under uncertainty with risk-based
Expand Down Expand Up @@ -932,6 +933,7 @@ def objfun_penalty(x):
maxiter=maxiter,
maxfeval=maxfeval,
u_bounds=bounds_np,
progress_hook=progress_hook,
).solve()

# Rounding up to given number of decimal places
Expand Down
4 changes: 3 additions & 1 deletion pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(
maxfeval: int = 100,
maxiter: int = 100,
u_bounds: np.ndarray = np.atleast_2d([[0], [1]]),
progress_hook: Callable = lambda i: None,
):
self.x0 = np.squeeze(np.array([x0]))
self.objfun = objfun
Expand All @@ -197,6 +198,7 @@ def __init__(
self.maxiter = maxiter
self.maxfeval = maxfeval
self.u_bounds = u_bounds
self.progress_hook = progress_hook
# self.kwargs = kwargs

def solve(self):
Expand All @@ -211,7 +213,7 @@ def update_progress(xk):
constraints=self.constraints,
method="COBYLA",
tol=1e-5,
callback=update_progress,
callback=self.progress_hook,
options={
"rhobeg": 0.1
* np.linalg.norm(self.u_bounds[1, :] - self.u_bounds[0, :]),
Expand Down
17 changes: 17 additions & 0 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,19 @@ def test_output_format(
def test_optimize(model_fixture, start_time, end_time, num_samples):
logging_step_size = 1.0
model_url = model_fixture.url

class TestProgressHook:
def __init__(self):
self.iterations = []
# self.function_evals = []

def __call__(self, iteration):
# Log the iteration number
self.iterations.append(iteration)
# self.function_evals.append(feval)

progress_hook = TestProgressHook()

optimize_kwargs = {
**model_fixture.optimize_kwargs,
"solver_method": "euler",
Expand All @@ -570,6 +583,7 @@ def test_optimize(model_fixture, start_time, end_time, num_samples):
"n_samples_ouu": int(2),
"maxiter": 1,
"maxfeval": 2,
"progress_hook": progress_hook,
}
bounds_interventions = optimize_kwargs["bounds_interventions"]
opt_result = optimize(
Expand Down Expand Up @@ -617,6 +631,9 @@ def test_optimize(model_fixture, start_time, end_time, num_samples):
intervened_result_subset, start_time, end_time, logging_step_size, num_samples
)

assert len(progress_hook.iterations) == (optimize_kwargs["maxfeval"] + 1) * (optimize_kwargs["maxiter"] + 1)
# assert len(progress_hook.function_evals) == optimize_kwargs["maxfeval"]


@pytest.mark.parametrize("model_fixture", MODELS)
@pytest.mark.parametrize("bad_num_iterations", NON_POS_INTS)
Expand Down

0 comments on commit 24369d8

Please sign in to comment.