From 5926fac9662fbaf65412dbee80012fefe192f03f Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 9 Oct 2025 18:41:03 +0200 Subject: [PATCH 01/34] Add FISTA adaptation and adapter as OptimistixFISTA --- src/nemos/solvers/__init__.py | 1 + src/nemos/solvers/_fista_port.py | 311 +++++++++++++++++++++++ src/nemos/solvers/_optimistix_solvers.py | 4 +- 3 files changed, 314 insertions(+), 2 deletions(-) create mode 100644 src/nemos/solvers/_fista_port.py diff --git a/src/nemos/solvers/__init__.py b/src/nemos/solvers/__init__.py index 71906ddeb..a72980898 100644 --- a/src/nemos/solvers/__init__.py +++ b/src/nemos/solvers/__init__.py @@ -20,3 +20,4 @@ glm_softplus_poisson_l_max_and_l, svrg_optimal_batch_and_stepsize, ) +from ._fista_port import OptimistixFISTA diff --git a/src/nemos/solvers/_fista_port.py b/src/nemos/solvers/_fista_port.py new file mode 100644 index 000000000..79ed1768d --- /dev/null +++ b/src/nemos/solvers/_fista_port.py @@ -0,0 +1,311 @@ +"""Adaptation of JAXopt's ProximalGradient (FISTA) as an Optimistix IterativeSolver.""" + +import operator +from typing import Any, Callable + +import equinox as eqx +import jax +import jax.numpy as jnp +import optimistix as optx +from jaxtyping import Array, Bool, Float, Int, PyTree +from optimistix._custom_types import Aux, Y + +from ._optimistix_solvers import OptimistixAdapter + + +def tree_sub(x, y): + return jax.tree.map(operator.sub, x, y) + + +# adapted from jaxopt +def tree_add_scalar_mul(tree_x: PyTree, scalar, tree_y): + return jax.tree.map(lambda x, y: x + scalar * y, tree_x, tree_y) + + +def tree_nan_like(x: PyTree): + return jax.tree.map(lambda arr: jnp.full_like(arr, jnp.nan), x) + + +class ProxGradState(eqx.Module): + iter_num: Int[Array, ""] + stepsize: Float[Array, ""] + velocity: PyTree + t: Float[Array, ""] + f: Float[Array, ""] + + terminate: Bool[Array, ""] + + +class FISTA(optx.AbstractMinimiser[Y, Aux, ProxGradState]): + prox: Callable + regularizer_strength: float + + atol: float + rtol: float + norm: Callable + + stepsize: float | None = None + maxls: int = 15 + decrease_factor: float = 0.5 + max_stepsize: float | None = 1.0 + + acceleration: bool = True + + def init( + self, + fn: Callable, + y: Y, + args: PyTree[Any], + options: dict[str, Any], + f_struct: PyTree[jax.ShapeDtypeStruct], + aux_struct: PyTree[jax.ShapeDtypeStruct], + tags: frozenset[object], + ) -> ProxGradState: + del options, f_struct, aux_struct, tags + fun_val, _ = fn(y, args) + + if self.acceleration: + vel = y + t = jnp.asarray(1.0) + else: + vel = tree_nan_like(y) + t = jnp.asarray(jnp.nan) + + return ProxGradState( + iter_num=jnp.asarray(0), + velocity=vel, + t=t, + stepsize=jnp.asarray(1.0), + terminate=jnp.asarray(False), + f=fun_val, + ) + + def step( + self, + fn: Callable, + y: Y, + args: PyTree[Any], + options: dict[str, Any], + state: ProxGradState, + tags: frozenset[object], + ) -> tuple[Y, ProxGradState, Aux]: + del tags + + # Some clarification on variable names because Optimistix's and the FISTA paper's / JAXopt's notation are different: + # + # In the paper x_{i} are the parameters, y_{i} the points after the momentum step. + # The updates: + # x_{k} = prox(y_{k} - stepsize_{k} * gradient_at_y_{k}) + # + # t_{k+1} = (1 + sqrt(1 + 4 t_{k}^2)) / 2 + # y_{k+1} = x_{k} + ((t_{k} - 1) / t_{k+1}) * (x_{k} - x_{k-1}) + # + # Where we run a linesearch to find the stepsize_{k} starting with stepsize_{k-1} / decrease_factor or 1. 
+ # Note that instead of x_{k}, the current parameter values, the linsearch is done "looking out" from y_{k} + # in the direction of the gradient at that point. + # Also, the new t_{k+1} and y_{k+1} are precalculated to be used on the next iteration. + # + # In Optimistix the parameter values are denoted by y, + # so what is x in the formula above will be y in the code, + # and what is y in the formula will be called velocity or vel. + # + # In order of appearance in the code: + # y = x_{k-1} + # state.velocity = y_{k} # note that it was precalculated at the previous step + # state.stepsize = stepsize_{k-1} + # new_y = x_{k} + # new_stepsize = stepsize_{k} + # state.t = t_{k} # also precalculated at the previous step + # next_t = t_{k+1} + # next_vel = y_{k+1} + + if self.acceleration: + update_point = state.velocity + else: + update_point = y + + new_y, new_stepsize = self._update_at_point( + fn, update_point, args, options, state + ) + # TODO: These could be returned from _update_at_point + # because the linesearch already calculates it + # so a function evaluation could be saved here. + new_fun_val, new_aux = fn(new_y, args) + diff_y = tree_sub(new_y, y) + + if self.acceleration: + next_t = 0.5 * (1 + jnp.sqrt(1 + 4 * state.t**2)) + next_vel = tree_add_scalar_mul(new_y, (state.t - 1) / next_t, diff_y) + else: + next_t = state.t + next_vel = state.velocity + + # NOTE do we want to use Cauchy for consistency with other solvers + # or the other to be consistent with JAXopt? + terminate = optx._misc.cauchy_termination( + self.rtol, + self.atol, + self.norm, + y, + diff_y, + state.f, + new_fun_val - state.f, + ) + # terminate = (optx.two_norm(diff_y) / new_stepsize) < self.atol + + next_state = ProxGradState( + iter_num=state.iter_num + 1, + velocity=next_vel, + t=next_t, + stepsize=jnp.asarray(new_stepsize), + terminate=terminate, + f=new_fun_val, + ) + + return new_y, next_state, new_aux + + def _update_at_point( + self, + fn: Callable, + update_point: Y, + args: PyTree[Any], + options: dict[str, Any], + state: ProxGradState, + ): + """ + Perform the update with or without linesearch around `update_point`. + + If acceleration is used (FISTA), `update_point` is state.velocity ~ y_{k}. + Without acceleration (ISTA) `update_point` is `y` ~ x_{k-1}. + """ + autodiff_mode = options.get("autodiff_mode", "bwd") + f_at_point, lin_fn, _ = jax.linearize( + lambda _y: fn(_y, args), update_point, has_aux=True + ) + grad_at_point = optx._misc.lin_to_grad( + lin_fn, update_point, autodiff_mode=autodiff_mode + ) + + if self.stepsize is None or self.stepsize <= 0.0: + # do linesearch to find the new stepsize + fun_without_aux = lambda params, args: fn(params, args)[0] + new_y, new_stepsize = self.fista_line_search( + fun_without_aux, + update_point, + f_at_point, + grad_at_point, + state.stepsize, + args, + ) + + # attempt to increase the stepsize for the new linesearch + # or reset it if it's very small + new_stepsize = jnp.where( + new_stepsize <= 1e-6, + jnp.array(1.0), + new_stepsize / self.decrease_factor, + ) + # B: in my experience, this guard helps stabilize and reduce the number of iterations + # For some reason, without it this implementation sometimes needs more iterations than the + # original JAXopt implementation, which in theory should be mathematically identical. 
+ if self.max_stepsize is not None: + new_stepsize = jnp.minimum(new_stepsize, self.max_stepsize) + else: + # use the fixed stepsize + new_stepsize = self.stepsize + new_y = tree_add_scalar_mul(update_point, -new_stepsize, grad_at_point) + new_y = self.prox(new_y, self.regularizer_strength, new_stepsize) + + return new_y, new_stepsize + + # adapted from JAXopt + def fista_line_search( + self, + fun: Callable, + x: Y, + x_fun_val: Float[Array, ""], + grad: Y, + stepsize: Float[Array, ""], + args: PyTree[Any], + ) -> tuple[Y, Float[Array, ""]]: + # epsilon of current dtype for robust checking of + # sufficient decrease condition + eps = jnp.finfo(x_fun_val.dtype).eps + + def cond_fun(carry): + next_x, stepsize = carry + + new_fun_val = fun(next_x, args) + + diff_x = tree_sub(next_x, x) + sqdist = optx._misc.sum_squares(diff_x) + + # verbatim from JAXopt + # The expression below checks the sufficient decrease condition + # f(next_x) < f(x) + dot(grad_f(x), diff_x) + (0.5/stepsize) ||diff_x||^2 + # where the terms have been reordered for numerical stability. + fun_decrease = stepsize * (new_fun_val - x_fun_val) + expected_decrease = ( + stepsize * optx._misc.tree_dot(diff_x, grad) + 0.5 * sqdist + ) + + return fun_decrease > expected_decrease + eps + + def body_fun(carry): + stepsize = carry[1] + new_stepsize = stepsize * self.decrease_factor + next_x = tree_add_scalar_mul(x, -new_stepsize, grad) + next_x = self.prox(next_x, self.regularizer_strength, new_stepsize) + return next_x, new_stepsize + + init_x = tree_add_scalar_mul(x, -stepsize, grad) + init_x = self.prox(init_x, self.regularizer_strength, stepsize) + init_val = (init_x, stepsize) + + # TODO: make kind dependent on the adjoint used? + # "lax" for implicit, "checkpointed" for RecursiveCheckpointAdjoint + # could make a partial function where these are available + # or just accept it in __init__? 
+ return eqx.internal.while_loop( + cond_fun=cond_fun, + body_fun=body_fun, + init_val=init_val, + max_steps=self.maxls, + kind="lax", + ) + + def terminate( + self, + fn: Callable, + y: Y, + args: PyTree[Any], + options: dict[str, Any], + state: ProxGradState, + tags: frozenset[object], + ) -> tuple[Bool[Array, ""], optx._solution.RESULTS]: + del fn, y, args, options, tags + + return state.terminate, optx._solution.RESULTS.successful + + def postprocess( + self, + fn: Callable, + y: Y, + aux: Aux, + args: PyTree[Any], + options: dict[str, Any], + state: ProxGradState, + tags: frozenset[object], + result: optx._solution.RESULTS, + ) -> tuple[Y, Aux, dict[str, Any]]: + del fn, args, options, state, tags, result + return y, aux, {} + + +class OptimistixFISTA(OptimistixAdapter): + _solver_cls = FISTA + _proximal = True + + @property + def maxiter(self): + return self.config.maxiter diff --git a/src/nemos/solvers/_optimistix_solvers.py b/src/nemos/solvers/_optimistix_solvers.py index e6ce2ff3e..6706681ed 100644 --- a/src/nemos/solvers/_optimistix_solvers.py +++ b/src/nemos/solvers/_optimistix_solvers.py @@ -99,8 +99,8 @@ def __init__( if self._proximal: loss_fn = unregularized_loss - self.prox = regularizer.get_proximal_operator() - self.regularizer_strength = regularizer_strength + solver_init_kwargs["prox"] = regularizer.get_proximal_operator() + solver_init_kwargs["regularizer_strength"] = regularizer_strength else: loss_fn = regularizer.penalized_loss( unregularized_loss, regularizer_strength From 618e769a5db9204753fbfac270160596f3933123 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 9 Oct 2025 18:42:50 +0200 Subject: [PATCH 02/34] Use OptimistixFISTA when using the Optimistix backend in tests --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 179334ec4..e79144a2c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1129,7 +1129,7 @@ def population_negativeBinomialGLM_model_instantiation_pytree( "optimistix": { **_common_solvers, "GradientDescent": nmo.solvers.OptimistixOptaxGradientDescent, - "ProximalGradient": nmo.solvers.OptimistixOptaxProximalGradient, + "ProximalGradient": nmo.solvers.OptimistixFISTA, "LBFGS": nmo.solvers.OptimistixOptaxLBFGS, "BFGS": nmo.solvers.OptimistixBFGS, "NonlinearCG": nmo.solvers.OptimistixNonlinearCG, From 139990da17778d736b658c6dd26678a0b0ca4576 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 9 Oct 2025 20:52:07 +0200 Subject: [PATCH 03/34] Add OptimistixGradientDescent that uses the accelerated JAXopt port --- src/nemos/solvers/__init__.py | 2 +- src/nemos/solvers/_fista_port.py | 34 ++++++++++++++++++++++---------- tests/conftest.py | 2 +- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/nemos/solvers/__init__.py b/src/nemos/solvers/__init__.py index a72980898..715f30507 100644 --- a/src/nemos/solvers/__init__.py +++ b/src/nemos/solvers/__init__.py @@ -20,4 +20,4 @@ glm_softplus_poisson_l_max_and_l, svrg_optimal_batch_and_stepsize, ) -from ._fista_port import OptimistixFISTA +from ._fista_port import OptimistixFISTA, OptimistixGradientDescent diff --git a/src/nemos/solvers/_fista_port.py b/src/nemos/solvers/_fista_port.py index 79ed1768d..59f83be21 100644 --- a/src/nemos/solvers/_fista_port.py +++ b/src/nemos/solvers/_fista_port.py @@ -1,7 +1,7 @@ """Adaptation of JAXopt's ProximalGradient (FISTA) as an Optimistix IterativeSolver.""" import operator -from typing import Any, Callable +from typing import Any, Callable, ClassVar import 
equinox as eqx import jax @@ -13,6 +13,12 @@ from ._optimistix_solvers import OptimistixAdapter +def prox_none(x, hyperparams=None, scaling: float = 1.0): + """Identity proximal operator.""" + del hyperparams, scaling + return x + + def tree_sub(x, y): return jax.tree.map(operator.sub, x, y) @@ -37,13 +43,13 @@ class ProxGradState(eqx.Module): class FISTA(optx.AbstractMinimiser[Y, Aux, ProxGradState]): - prox: Callable - regularizer_strength: float - atol: float rtol: float norm: Callable + prox: Callable + regularizer_strength: float | None + stepsize: float | None = None maxls: int = 15 decrease_factor: float = 0.5 @@ -140,8 +146,7 @@ def step( next_t = state.t next_vel = state.velocity - # NOTE do we want to use Cauchy for consistency with other solvers - # or the other to be consistent with JAXopt? + # use Cauchy for consistency with other solvers terminate = optx._misc.cauchy_termination( self.rtol, self.atol, @@ -151,7 +156,6 @@ def step( state.f, new_fun_val - state.f, ) - # terminate = (optx.two_norm(diff_y) / new_stepsize) < self.atol next_state = ProxGradState( iter_num=state.iter_num + 1, @@ -302,10 +306,20 @@ def postprocess( return y, aux, {} +class GradientDescent(FISTA): + prox: ClassVar[Callable] = staticmethod(prox_none) + regularizer_strength: float | None = None + + class OptimistixFISTA(OptimistixAdapter): + """Port of JAXopt's ProximalGradient to the Optimistix API.""" + _solver_cls = FISTA _proximal = True - @property - def maxiter(self): - return self.config.maxiter + +class OptimistixGradientDescent(OptimistixAdapter): + """Port of JAXopt's accelerated GradientDescent to the Optimistix API.""" + + _solver_cls = GradientDescent + _proximal = False diff --git a/tests/conftest.py b/tests/conftest.py index e79144a2c..6e8816034 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1128,7 +1128,7 @@ def population_negativeBinomialGLM_model_instantiation_pytree( }, "optimistix": { **_common_solvers, - "GradientDescent": nmo.solvers.OptimistixOptaxGradientDescent, + "GradientDescent": nmo.solvers.OptimistixGradientDescent, "ProximalGradient": nmo.solvers.OptimistixFISTA, "LBFGS": nmo.solvers.OptimistixOptaxLBFGS, "BFGS": nmo.solvers.OptimistixBFGS, From bb94b25c5f30e6b71059959a80f156f4bdc0a505 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 9 Oct 2025 20:52:38 +0200 Subject: [PATCH 04/34] Use two_norm by default in Cauchy criterion --- src/nemos/solvers/_optimistix_solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/solvers/_optimistix_solvers.py b/src/nemos/solvers/_optimistix_solvers.py index 6706681ed..e573c355c 100644 --- a/src/nemos/solvers/_optimistix_solvers.py +++ b/src/nemos/solvers/_optimistix_solvers.py @@ -52,7 +52,7 @@ class OptimistixConfig: # sets if the minimisation throws an error if an iterative solver runs out of steps throw: bool = False # norm used in the Cauchy convergence criterion. Required by all Optimistix solvers. - norm: Callable = optx.max_norm + norm: Callable = optx.two_norm # way of autodifferentiation: https://docs.kidger.site/optimistix/api/adjoints/ adjoint: optx.AbstractAdjoint = optx.ImplicitAdjoint() # whether the objective function returns any auxiliary results. From ab09fe596aa1ed2f417ded5f7c37fab31f3a0551 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 9 Oct 2025 20:52:55 +0200 Subject: [PATCH 05/34] Remove prox from accepted arguments in Optimistix solvers In JAXopt solvers it was already removed. It is created by the regularizer, not settable by users. 
--- src/nemos/solvers/_optimistix_solvers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/nemos/solvers/_optimistix_solvers.py b/src/nemos/solvers/_optimistix_solvers.py index e573c355c..0e6f42611 100644 --- a/src/nemos/solvers/_optimistix_solvers.py +++ b/src/nemos/solvers/_optimistix_solvers.py @@ -189,6 +189,10 @@ def get_accepted_arguments(cls) -> set[str]: ) all_arguments = own_and_solver_args | common_optx_arguments + # prox is read from the regularizer, not provided as a solver argument + if cls._proximal: + all_arguments.remove("prox") + return all_arguments @classmethod From 9d630e0809eb6de303920068eea1e52292854e5a Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 9 Oct 2025 20:56:49 +0200 Subject: [PATCH 06/34] Add todo --- src/nemos/solvers/_fista_port.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/nemos/solvers/_fista_port.py b/src/nemos/solvers/_fista_port.py index 59f83be21..af027a27d 100644 --- a/src/nemos/solvers/_fista_port.py +++ b/src/nemos/solvers/_fista_port.py @@ -12,6 +12,8 @@ from ._optimistix_solvers import OptimistixAdapter +# TODO: Add detailed docstrings + def prox_none(x, hyperparams=None, scaling: float = 1.0): """Identity proximal operator.""" From 49fd795b48dfbbed4ef51f1153378bda61b6c51b Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 10 Oct 2025 10:59:03 +0200 Subject: [PATCH 07/34] Add docstrings --- src/nemos/solvers/_fista_port.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/nemos/solvers/_fista_port.py b/src/nemos/solvers/_fista_port.py index af027a27d..c5d76a67e 100644 --- a/src/nemos/solvers/_fista_port.py +++ b/src/nemos/solvers/_fista_port.py @@ -35,6 +35,8 @@ def tree_nan_like(x: PyTree): class ProxGradState(eqx.Module): + """ProximalGradient (FISTA) solver state.""" + iter_num: Int[Array, ""] stepsize: Float[Array, ""] velocity: PyTree @@ -45,6 +47,17 @@ class ProxGradState(eqx.Module): class FISTA(optx.AbstractMinimiser[Y, Aux, ProxGradState]): + """ + Accelerated Proximal Gradient (FISTA) as an Optimistix minimiser. Adapted from JAXopt. + + References + ---------- + .. [1] Beck, A., & Teboulle, M. (2009). + "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse Problems." + *SIAM Journal on Imaging Sciences*, 2(1), 183–202. + https://doi.org/10.1137/080716542 + """ + atol: float rtol: float norm: Callable @@ -309,6 +322,8 @@ def postprocess( class GradientDescent(FISTA): + """Gradient descent with Nesterov acceleration. 
Adapted from JAXopt.""" + prox: ClassVar[Callable] = staticmethod(prox_none) regularizer_strength: float | None = None From 85a26d02a88706bdf29e5855fa13805c117458ea Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 10 Oct 2025 11:26:41 +0200 Subject: [PATCH 08/34] OptimistixGradientDescent -> OptimistixNAG --- docs/developers_notes/07-solvers.md | 2 ++ src/nemos/solvers/__init__.py | 2 +- src/nemos/solvers/_fista_port.py | 4 ++-- tests/conftest.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/developers_notes/07-solvers.md b/docs/developers_notes/07-solvers.md index 0744ecc57..cf601c12a 100644 --- a/docs/developers_notes/07-solvers.md +++ b/docs/developers_notes/07-solvers.md @@ -66,6 +66,8 @@ Abstract Class AbstractSolver │ │ │ │ │ ├─ Concrete Subclass OptimistixBFGS │ │ ├─ Concrete Subclass OptimistixLBFGS +│ │ ├─ Concrete Subclass OptimistixFISTA +│ │ ├─ Concrete Subclass OptimistixNAG │ │ ├─ Concrete Subclass OptimistixNonlinearCG │ │ └─ Abstract Subclass AbstractOptimistixOptaxSolver │ │ │ diff --git a/src/nemos/solvers/__init__.py b/src/nemos/solvers/__init__.py index 715f30507..5b8e7800a 100644 --- a/src/nemos/solvers/__init__.py +++ b/src/nemos/solvers/__init__.py @@ -20,4 +20,4 @@ glm_softplus_poisson_l_max_and_l, svrg_optimal_batch_and_stepsize, ) -from ._fista_port import OptimistixFISTA, OptimistixGradientDescent +from ._fista_port import OptimistixFISTA, OptimistixNAG diff --git a/src/nemos/solvers/_fista_port.py b/src/nemos/solvers/_fista_port.py index c5d76a67e..4cf6e7c7a 100644 --- a/src/nemos/solvers/_fista_port.py +++ b/src/nemos/solvers/_fista_port.py @@ -335,8 +335,8 @@ class OptimistixFISTA(OptimistixAdapter): _proximal = True -class OptimistixGradientDescent(OptimistixAdapter): - """Port of JAXopt's accelerated GradientDescent to the Optimistix API.""" +class OptimistixNAG(OptimistixAdapter): + """Port of Nesterov's accelerated gradient descent from JAXopt to the Optimistix API.""" _solver_cls = GradientDescent _proximal = False diff --git a/tests/conftest.py b/tests/conftest.py index 6e8816034..5fa45db97 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1128,7 +1128,7 @@ def population_negativeBinomialGLM_model_instantiation_pytree( }, "optimistix": { **_common_solvers, - "GradientDescent": nmo.solvers.OptimistixGradientDescent, + "GradientDescent": nmo.solvers.OptimistixNAG, "ProximalGradient": nmo.solvers.OptimistixFISTA, "LBFGS": nmo.solvers.OptimistixOptaxLBFGS, "BFGS": nmo.solvers.OptimistixBFGS, From 12a6f219894940c5aa478e4e426ad64c5c5b63f1 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 10 Oct 2025 11:27:51 +0200 Subject: [PATCH 09/34] Remove OptimistixOptaxProximalGradient Also correct OptimistixOptax... 
names in the dev notes --- docs/developers_notes/07-solvers.md | 3 +- src/nemos/solvers/__init__.py | 1 - .../solvers/_optax_optimistix_solvers.py | 170 ------------------ 3 files changed, 1 insertion(+), 173 deletions(-) diff --git a/docs/developers_notes/07-solvers.md b/docs/developers_notes/07-solvers.md index cf601c12a..b20e54408 100644 --- a/docs/developers_notes/07-solvers.md +++ b/docs/developers_notes/07-solvers.md @@ -72,8 +72,7 @@ Abstract Class AbstractSolver │ │ └─ Abstract Subclass AbstractOptimistixOptaxSolver │ │ │ │ │ ├─ Concrete Subclass OptimistixOptaxLBFGS -│ │ ├─ Concrete Subclass OptimistixOptaxGradientDescent -│ │ └─ Concrete Subclass OptimistixOptaxProximalGradient +│ │ └─ Concrete Subclass OptimistixOptaxGradientDescent │ │ │ └─ Abstract Subclass JaxoptAdapter │ │ diff --git a/src/nemos/solvers/__init__.py b/src/nemos/solvers/__init__.py index 5b8e7800a..ca1294480 100644 --- a/src/nemos/solvers/__init__.py +++ b/src/nemos/solvers/__init__.py @@ -10,7 +10,6 @@ from ._optax_optimistix_solvers import ( OptimistixOptaxGradientDescent, OptimistixOptaxLBFGS, - OptimistixOptaxProximalGradient, ) from ._optimistix_solvers import OptimistixBFGS, OptimistixNonlinearCG from ._solver_doc_helper import get_solver_documentation diff --git a/src/nemos/solvers/_optax_optimistix_solvers.py b/src/nemos/solvers/_optax_optimistix_solvers.py index 701b7092e..fa4786d30 100644 --- a/src/nemos/solvers/_optax_optimistix_solvers.py +++ b/src/nemos/solvers/_optax_optimistix_solvers.py @@ -215,176 +215,6 @@ def _note_about_accepted_arguments(cls) -> str: return inspect.cleandoc(note + "\n" + accel_nesterov) -class OptimistixOptaxProximalGradient(AbstractOptimistixOptaxSolver): - """ - ProximalGradient implementation combining Optax and Optimistix. - - Uses Optax's SGD combined with Optax's zoom linesearch or a constant learning rate. - Then uses the learning rate given by Optax to scale the proximal - operator's update and check for convergence using Optimistix's criterion. - - Works with the same proximal operator functions as JAXopt did. - """ - - fun: Callable - fun_with_aux: Callable - prox: Callable - - # stats: dict[str, PyTree[ArrayLike]] - stats: dict[str, Pytree] - - _proximal: ClassVar[bool] = True - - def __init__( - self, - unregularized_loss: Callable, - regularizer: Regularizer, - regularizer_strength: float | None, - tol: float = DEFAULT_ATOL, - rtol: float = DEFAULT_RTOL, - maxiter: float = DEFAULT_MAX_STEPS, - momentum: float | None = None, - acceleration: bool = True, - stepsize: float | None = None, - linesearch_kwargs: dict | None = None, - **solver_init_kwargs, - ): - """ - Create a proximal gradient solver using `optax.sgd` and applying the proximal operator on each update step. - - If `acceleration` is True, use the Nesterov acceleration as defined by Sutskever et al. 2013. - Note that this is different from the Nesterov acceleration implemented by JAXopt and - only has an effect if `momentum` is used as well. - - If `stepsize` is not None and larger than 0, use it as a constant learning rate. - Otherwise `optax.scale_by_zoom_linesearch` is used with the curvature condition - disabled, which reduces to a backtracking linesearch where the stepsizes are chosen - using the cubic or quadratic interpolation used in zoom linesearch. - By default, 15 linesearch steps are used, which can be overwritten with - `max_linesearch_steps` in `linesearch_kwargs`. - - References - ---------- - [1] [Sutskever, I., Martens, J., Dahl, G. & Hinton, G.. (2013). 
- "On the importance of initialization and momentum in deep learning." - Proceedings of the 30th International Conference on Machine Learning, PMLR 28(3):1139-1147, 2013. - ](https://proceedings.mlr.press/v28/sutskever13.html) - """ - _sgd = optax.chain( - optax.sgd(learning_rate=1.0, momentum=momentum, nesterov=acceleration), - _make_rate_scaler(stepsize, linesearch_kwargs), - ) - solver_init_kwargs["optim"] = _sgd - - super().__init__( - unregularized_loss, - regularizer, - regularizer_strength, - tol=tol, - rtol=rtol, - maxiter=maxiter, - **solver_init_kwargs, - ) - - @classmethod - def get_accepted_arguments(cls) -> set[str]: - arguments = super().get_accepted_arguments() - - arguments.discard("optim") # we create this, it can't be passed - - return arguments - - def get_learning_rate(self, state: optx._solver.optax._OptaxState) -> float: - """ - Read out the learning rate for scaling within the proximal operator. - - This learning rate is either a static learning rate or was found by a linesearch. - """ - return state.opt_state[-1].learning_rate - - def step( - self, - fn: Callable, - y: Params, - args: Pytree, - options: dict[str, Any], - state: optx._solver.optax._OptaxState, - tags: frozenset[object], - ): - # take gradient step - new_params, new_state, new_aux = self._solver.step( - fn, y, args, options, state, tags - ) - - # apply the proximal operator - new_params = self.prox( - new_params, - self.regularizer_strength, - self.get_learning_rate(new_state), - ) - - # reevaluate function value at the new point - new_state = eqx.tree_at(lambda s: s.f, new_state, fn(new_params, args)[0]) - - # recheck convergence criteria with the projected point - updates = tree_sub(new_params, y) - - # replicating the jaxopt stopping criterion - terminate = ( - optx.two_norm(updates) / self.get_learning_rate(new_state) - < self._solver.atol - ) - - new_state = eqx.tree_at(lambda s: s.terminate, new_state, terminate) - - return new_params, new_state, new_aux - - def run( - self, - init_params: Params, - *args, - ) -> OptimistixStepResult: - solution = optx.minimise( - fn=self.fun, - solver=self, # pyright: ignore - y0=init_params, - args=args, - options=self.config.options, - has_aux=self.config.has_aux, - max_steps=self.config.maxiter, - adjoint=self.config.adjoint, - throw=self.config.throw, - tags=self.config.tags, - ) - - self.stats.update(solution.stats) - - return solution.value, solution.state - - def init(self, *args, **kwargs): - # so that when passing self to optx.minimise, init can be called - return self._solver.init(*args, **kwargs) - - def terminate(self, *args, **kwargs): - # so that when passing self to optx.minimise, terminate can be called - return self._solver.terminate(*args, **kwargs) - - def postprocess(self, *args, **kwargs): - # so that when passing self to optx.minimise, postprocess can be called - return self._solver.postprocess(*args, **kwargs) - - @classmethod - def _note_about_accepted_arguments(cls) -> str: - note = super()._note_about_accepted_arguments() - accel_nesterov = inspect.cleandoc( - """ - `acceleration` is passed to `optax.sgd` as the `nesterov` parameter. - Note that this only has an effect if `momentum` is used as well. - """ - ) - return inspect.cleandoc(note + "\n" + accel_nesterov) - - class OptimistixOptaxLBFGS(AbstractOptimistixOptaxSolver): """ L-BFGS implementation using optax.lbfgs wrapped by optimistix.OptaxMinimiser. 
From 22663839efe59b483a9a77e2adaee3fd06c359c1 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 10 Oct 2025 15:07:27 +0200 Subject: [PATCH 10/34] Use tree_sub from tree_utils --- src/nemos/solvers/_fista_port.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/nemos/solvers/_fista_port.py b/src/nemos/solvers/_fista_port.py index 4cf6e7c7a..b2e13567b 100644 --- a/src/nemos/solvers/_fista_port.py +++ b/src/nemos/solvers/_fista_port.py @@ -1,6 +1,5 @@ """Adaptation of JAXopt's ProximalGradient (FISTA) as an Optimistix IterativeSolver.""" -import operator from typing import Any, Callable, ClassVar import equinox as eqx @@ -9,6 +8,7 @@ import optimistix as optx from jaxtyping import Array, Bool, Float, Int, PyTree from optimistix._custom_types import Aux, Y +from ..tree_utils import tree_sub from ._optimistix_solvers import OptimistixAdapter @@ -21,10 +21,6 @@ def prox_none(x, hyperparams=None, scaling: float = 1.0): return x -def tree_sub(x, y): - return jax.tree.map(operator.sub, x, y) - - # adapted from jaxopt def tree_add_scalar_mul(tree_x: PyTree, scalar, tree_y): return jax.tree.map(lambda x, y: x + scalar * y, tree_x, tree_y) From ebf8b00638e8ccb19456fb02e94e7b26736d0736 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 10 Oct 2025 15:16:58 +0200 Subject: [PATCH 11/34] Use tree_add_scalar_mul from tree_utils --- src/nemos/solvers/_fista_port.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/nemos/solvers/_fista_port.py b/src/nemos/solvers/_fista_port.py index b2e13567b..8b29f3561 100644 --- a/src/nemos/solvers/_fista_port.py +++ b/src/nemos/solvers/_fista_port.py @@ -8,8 +8,8 @@ import optimistix as optx from jaxtyping import Array, Bool, Float, Int, PyTree from optimistix._custom_types import Aux, Y -from ..tree_utils import tree_sub +from ..tree_utils import tree_add_scalar_mul, tree_sub from ._optimistix_solvers import OptimistixAdapter # TODO: Add detailed docstrings @@ -21,11 +21,6 @@ def prox_none(x, hyperparams=None, scaling: float = 1.0): return x -# adapted from jaxopt -def tree_add_scalar_mul(tree_x: PyTree, scalar, tree_y): - return jax.tree.map(lambda x, y: x + scalar * y, tree_x, tree_y) - - def tree_nan_like(x: PyTree): return jax.tree.map(lambda arr: jnp.full_like(arr, jnp.nan), x) From 10401cd5dcbd987754755aefe0a712fc97f155d1 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 10 Oct 2025 15:18:04 +0200 Subject: [PATCH 12/34] Switch to backtracking in OptimistixOptaxGradientDescent Also remove unused code. 
--- .../solvers/_optax_optimistix_solvers.py | 66 ++++--------------- 1 file changed, 12 insertions(+), 54 deletions(-) diff --git a/src/nemos/solvers/_optax_optimistix_solvers.py b/src/nemos/solvers/_optax_optimistix_solvers.py index fa4786d30..cf9168658 100644 --- a/src/nemos/solvers/_optax_optimistix_solvers.py +++ b/src/nemos/solvers/_optax_optimistix_solvers.py @@ -2,24 +2,18 @@ import abc import inspect -from typing import Any, Callable, ClassVar, NamedTuple, Union +from typing import Any, Callable, ClassVar -import equinox as eqx -import jax -import jax.numpy as jnp import optax import optimistix as optx from ..regularizer import Regularizer -from ..tree_utils import tree_sub from ..typing import Pytree from ._optimistix_solvers import ( DEFAULT_ATOL, DEFAULT_MAX_STEPS, DEFAULT_RTOL, OptimistixAdapter, - OptimistixStepResult, - Params, ) @@ -71,34 +65,7 @@ def __init_subclass__(cls, **kwargs): cls.__doc__ = inspect.cleandoc(full_doc) -class ScaleByLearningRateState(NamedTuple): - learning_rate: Union[float, jax.Array] - - -def stateful_scale_by_learning_rate( - stepsize: float, flip_sign: bool = True -) -> optax.GradientTransformation: - """ - Reimplementation of optax.scale_by_learning_rate, just storing the learning rate in the state. - - Required for setting the scaling appropriately when used with - proximal gradient descent. - """ - m = -1 if flip_sign else 1 - - def init_fn(params): - del params - return ScaleByLearningRateState(jnp.array(stepsize)) - - def update_fn(updates, state, params=None): - del params - updates = jax.tree.map(lambda g: m * stepsize * g, updates) - - return updates, state - - return optax.GradientTransformation(init_fn, update_fn) # pyright: ignore - - +# TODO: Figure out how to test both OptimistixOptaxGradientDescent and OptimistixNAG def _make_rate_scaler( stepsize: float | None, linesearch_kwargs: dict[str, Any] | None, @@ -107,37 +74,30 @@ def _make_rate_scaler( Make an Optax transformation for setting the learning rate. If `stepsize` is not None and larger than 0, use it as a constant learning rate. - Otherwise `optax.scale_by_zoom_linesearch` is used with `linesearch_kwargs`. - By default the curvature condition is disabled, which reduces to a backtracking - linesearch where the stepsizes are chosen using the cubic or quadratic interpolation - used in zoom linesearch. + Otherwise `optax.scale_by_backtracking_linesearch` is used with `linesearch_kwargs`. By default, 15 linesearch steps are used, which can be overwritten with - `max_linesearch_steps` in `linesearch_kwargs`. + `max_backtracking_steps` in `linesearch_kwargs`. 
""" if stepsize is None or stepsize <= 0.0: if linesearch_kwargs is None: linesearch_kwargs = {} - if "max_linesearch_steps" not in linesearch_kwargs: - linesearch_kwargs["max_linesearch_steps"] = 15 - - if "curv_rtol" not in linesearch_kwargs: - linesearch_kwargs["curv_rtol"] = jnp.inf + if "max_backtracking_steps" not in linesearch_kwargs: + linesearch_kwargs["max_backtracking_steps"] = 15 - return optax.scale_by_zoom_linesearch(**linesearch_kwargs) + return optax.scale_by_backtracking_linesearch(**linesearch_kwargs) else: if linesearch_kwargs: raise ValueError("Only provide stepsize or linesearch_kwargs.") - # GradientDescent works with optax.scale_by_learning_rate as well - # but for ProximalGradient we need to be able to extract the current learning rate - return stateful_scale_by_learning_rate(stepsize) + + return optax.scale_by_learning_rate(stepsize) class OptimistixOptaxGradientDescent(AbstractOptimistixOptaxSolver): """ Gradient descent implementation combining Optax and Optimistix. - Uses Optax's SGD combined with Optax's zoom linesearch or a constant learning rate. + Uses Optax's SGD combined with Optax's backtracking linesearch or a constant learning rate. The full optimization loop is handled by the `optimistix.OptaxMinimiser` wrapper. """ @@ -166,11 +126,9 @@ def __init__( only has an effect if `momentum` is used as well. If `stepsize` is not None and larger than 0, use it as a constant learning rate. - Otherwise `optax.scale_by_zoom_linesearch` is used with the curvature condition - disabled, which reduces to a backtracking linesearch where the stepsizes are chosen - using the cubic or quadratic interpolation used in zoom linesearch. + Otherwise `optax.scale_by_backtracking_linesearch` is used. By default, 15 linesearch steps are used, which can be overwritten with - `max_linesearch_steps` in `linesearch_kwargs`. + `max_backtracking_steps` in `linesearch_kwargs`. 
References ---------- From 909a3d0f66f1fed48deac0fb228a6d0d722e064c Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 10 Oct 2025 15:18:43 +0200 Subject: [PATCH 13/34] Satisfy linter --- scripts/check_parameter_naming.py | 1 + src/nemos/solvers/__init__.py | 2 +- src/nemos/solvers/_fista_port.py | 6 +++--- tox.ini | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/scripts/check_parameter_naming.py b/scripts/check_parameter_naming.py index 0454934ca..af1d33424 100644 --- a/scripts/check_parameter_naming.py +++ b/scripts/check_parameter_naming.py @@ -83,6 +83,7 @@ {"solver_kwargs", "solver_init_kwargs"}, {"unaccepted_name", "accepted_name"}, {"fn", "fun"}, + {"ax", "aux"}, ] diff --git a/src/nemos/solvers/__init__.py b/src/nemos/solvers/__init__.py index ca1294480..57c542555 100644 --- a/src/nemos/solvers/__init__.py +++ b/src/nemos/solvers/__init__.py @@ -1,5 +1,6 @@ """Custom solvers module.""" +from ._fista_port import OptimistixFISTA, OptimistixNAG from ._jaxopt_solvers import ( JaxoptBFGS, JaxoptGradientDescent, @@ -19,4 +20,3 @@ glm_softplus_poisson_l_max_and_l, svrg_optimal_batch_and_stepsize, ) -from ._fista_port import OptimistixFISTA, OptimistixNAG diff --git a/src/nemos/solvers/_fista_port.py b/src/nemos/solvers/_fista_port.py index 8b29f3561..d0ae2be12 100644 --- a/src/nemos/solvers/_fista_port.py +++ b/src/nemos/solvers/_fista_port.py @@ -103,7 +103,8 @@ def step( ) -> tuple[Y, ProxGradState, Aux]: del tags - # Some clarification on variable names because Optimistix's and the FISTA paper's / JAXopt's notation are different: + # Some clarification on variable names because Optimistix's and the FISTA paper's / JAXopt's + # notation are different: # # In the paper x_{i} are the parameters, y_{i} the points after the momentum step. 
# The updates: @@ -198,9 +199,8 @@ def _update_at_point( if self.stepsize is None or self.stepsize <= 0.0: # do linesearch to find the new stepsize - fun_without_aux = lambda params, args: fn(params, args)[0] new_y, new_stepsize = self.fista_line_search( - fun_without_aux, + lambda params, args: fn(params, args)[0], update_point, f_at_point, grad_at_point, diff --git a/tox.ini b/tox.ini index 109d20222..0cd2d7873 100644 --- a/tox.ini +++ b/tox.ini @@ -118,4 +118,4 @@ exclude = ''' | src/nemos/third_party/jaxopt | __init__.py # Exclude __init__.py files ))''' -extend-ignore = W605, E203, DAR +extend-ignore = W605, E203, DAR, F722 From b80278241b55785dbd4cc1bd702d43c4b7b19dcf Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 10 Oct 2025 15:20:06 +0200 Subject: [PATCH 14/34] Change default tolerance to 1e-4 --- src/nemos/solvers/_optimistix_solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/solvers/_optimistix_solvers.py b/src/nemos/solvers/_optimistix_solvers.py index 0e6f42611..62865a13a 100644 --- a/src/nemos/solvers/_optimistix_solvers.py +++ b/src/nemos/solvers/_optimistix_solvers.py @@ -11,7 +11,7 @@ from ._abstract_solver import OptimizationInfo, Params from ._solver_adapter import SolverAdapter -DEFAULT_ATOL = 1e-8 +DEFAULT_ATOL = 1e-4 DEFAULT_RTOL = 0.0 DEFAULT_MAX_STEPS = 100_000 From 5a3f771bdd5c05b9b7288b7505921a79f068e7cb Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Wed, 22 Oct 2025 13:53:26 +0200 Subject: [PATCH 15/34] add todo --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index 5fa45db97..88c426c6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1128,6 +1128,7 @@ def population_negativeBinomialGLM_model_instantiation_pytree( }, "optimistix": { **_common_solvers, + # TODO: OptaxOptimistixGradientDescent is not tested "GradientDescent": nmo.solvers.OptimistixNAG, "ProximalGradient": nmo.solvers.OptimistixFISTA, "LBFGS": nmo.solvers.OptimistixOptaxLBFGS, From 9b504f090ddb29e9f07494543d7e51ff25647b13 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 24 Oct 2025 17:03:58 +0200 Subject: [PATCH 16/34] Update developer notes --- docs/developers_notes/07-solvers.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/developers_notes/07-solvers.md b/docs/developers_notes/07-solvers.md index b20e54408..bbda55585 100644 --- a/docs/developers_notes/07-solvers.md +++ b/docs/developers_notes/07-solvers.md @@ -87,8 +87,13 @@ Abstract Class AbstractSolver ``` `OptaxOptimistixSolver` is an adapter for Optax solvers, relying on `optimistix.OptaxMinimiser` to run the full optimization loop. -Optimistix does not have implementations of Nesterov acceleration, so gradient descent is implemented by wrapping `optax.sgd` which does support it. -(Although what Optax calls Nesterov acceleration is not the [original method developed for convex optimization](https://hengshuaiyao.github.io/papers/nesterov83.pdf) but the [version adapted for training deep networks with SGD](https://proceedings.mlr.press/v28/sutskever13.html). JAXopt did implement the original method, and [a port of this is planned to be added to NeMoS](https://github.com/flatironinstitute/nemos/issues/380).) + +Gradient descent is implemented by two classes: +- One is wrapping `optax.sgd` which supports momentum and acceleration. 
+Note that what Optax calls Nesterov acceleration is not the [original method developed for convex optimization](https://hengshuaiyao.github.io/papers/nesterov83.pdf) but the [version adapted for training deep networks with SGD](https://proceedings.mlr.press/v28/sutskever13.html). +- As JAXopt implemented the original method, a [port of JAXopt's `GradientDescent` was added to NeMoS](https://github.com/flatironinstitute/nemos/pull/411) as `OptimistixNAG`. + +Similarly to NAG, an accelerated proximal gradient algorithm ([FISTA](https://www.ceremade.dauphine.fr/~carlier/FISTA)) was [ported from JAXopt](https://github.com/flatironinstitute/nemos/pull/411) as `OptimistixFISTA`. Available solvers and which implementation they dispatch to are defined in the solver registry. A list of available solvers is provided by {py:func}`nemos.solvers.list_available_solvers`, and extended documentation about each solver can be accessed using {py:func}`nemos.solvers.get_solver_documentation`. From b90c0decb801d8e4db8ea7b1499214ecdfd117ef Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 12:27:01 +0100 Subject: [PATCH 17/34] Rename _fista_port.py to _fista.py --- src/nemos/solvers/{_fista_port.py => _fista.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/nemos/solvers/{_fista_port.py => _fista.py} (100%) diff --git a/src/nemos/solvers/_fista_port.py b/src/nemos/solvers/_fista.py similarity index 100% rename from src/nemos/solvers/_fista_port.py rename to src/nemos/solvers/_fista.py From c6271b3dab8da44c08c8bd711d8d5b11ca784856 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 12:27:17 +0100 Subject: [PATCH 18/34] Change module docstring --- src/nemos/solvers/_fista.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/solvers/_fista.py b/src/nemos/solvers/_fista.py index d0ae2be12..3b3be85cc 100644 --- a/src/nemos/solvers/_fista.py +++ b/src/nemos/solvers/_fista.py @@ -1,4 +1,4 @@ -"""Adaptation of JAXopt's ProximalGradient (FISTA) as an Optimistix IterativeSolver.""" +"""Implementation of the FISTA algorithm as an Optimistix IterativeSolver. 
Adapted from JAXopt.""" from typing import Any, Callable, ClassVar From 9507ca9872a91db26cfedb5834db6cda8e24a835 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 12:44:49 +0100 Subject: [PATCH 19/34] Typing --- src/nemos/solvers/_fista.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/solvers/_fista.py b/src/nemos/solvers/_fista.py index 3b3be85cc..9ac004178 100644 --- a/src/nemos/solvers/_fista.py +++ b/src/nemos/solvers/_fista.py @@ -15,7 +15,7 @@ # TODO: Add detailed docstrings -def prox_none(x, hyperparams=None, scaling: float = 1.0): +def prox_none(x: PyTree, hyperparams=None, scaling: float = 1.0): """Identity proximal operator.""" del hyperparams, scaling return x From f352c0a2883db009d712599a679445e3f0d5ba07 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 13:01:28 +0100 Subject: [PATCH 20/34] Extend docstring --- src/nemos/solvers/_fista.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/src/nemos/solvers/_fista.py b/src/nemos/solvers/_fista.py index 9ac004178..cebc49834 100644 --- a/src/nemos/solvers/_fista.py +++ b/src/nemos/solvers/_fista.py @@ -12,8 +12,6 @@ from ..tree_utils import tree_add_scalar_mul, tree_sub from ._optimistix_solvers import OptimistixAdapter -# TODO: Add detailed docstrings - def prox_none(x: PyTree, hyperparams=None, scaling: float = 1.0): """Identity proximal operator.""" @@ -39,7 +37,34 @@ class ProxGradState(eqx.Module): class FISTA(optx.AbstractMinimiser[Y, Aux, ProxGradState]): """ - Accelerated Proximal Gradient (FISTA) as an Optimistix minimiser. Adapted from JAXopt. + Accelerated Proximal Gradient (FISTA) [1] as an Optimistix minimiser. Adapted from JAXopt. + + Parameters + ---------- + atol: + Absolute tolerance for Cauchy termination. + rtol: + Relative tolerance for Cauchy termination. + norm: + Norm to use in Cauchy termination. + prox: + Proximal operator function. + regularizer_strength: + Regularizer strength passed to the proximal operator. + stepsize: + If None (default), use backtracking linesearch to determine an + appropriate stepsize on each iteration. + If a float, value for a constant stepsize. + maxls: + Maximum number of linesearch iterations. + decrease_factor: + Backtracking linesearch's decrease factor. + max_stepsize: + Maximum allowed stepsize. + If None, no maximum is used. + acceleration: + Whether to use Nesterov acceleration. 
+ References ---------- From 1e722d061bec9d781334c151273bc3296a30e44f Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 13:16:53 +0100 Subject: [PATCH 21/34] Fix import after file rename --- src/nemos/solvers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/solvers/__init__.py b/src/nemos/solvers/__init__.py index 57c542555..94a65761f 100644 --- a/src/nemos/solvers/__init__.py +++ b/src/nemos/solvers/__init__.py @@ -1,6 +1,6 @@ """Custom solvers module.""" -from ._fista_port import OptimistixFISTA, OptimistixNAG +from ._fista import OptimistixFISTA, OptimistixNAG from ._jaxopt_solvers import ( JaxoptBFGS, JaxoptGradientDescent, From 2fc6f020a8a7a59dc1336957e2848dd521952d4a Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 15:44:18 +0100 Subject: [PATCH 22/34] Add env vars to override solver implementation in tests --- tests/conftest.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 88c426c6b..90b125fa9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1128,7 +1128,6 @@ def population_negativeBinomialGLM_model_instantiation_pytree( }, "optimistix": { **_common_solvers, - # TODO: OptaxOptimistixGradientDescent is not tested "GradientDescent": nmo.solvers.OptimistixNAG, "ProximalGradient": nmo.solvers.OptimistixFISTA, "LBFGS": nmo.solvers.OptimistixOptaxLBFGS, @@ -1147,20 +1146,29 @@ def configure_solver_backend(): for the JAXopt and the Optimistix backends. """ backend = os.getenv("NEMOS_SOLVER_BACKEND") - if not backend: - yield # run with default solver registry - return # don't execute the remainder on teardown - try: - _backend_solver_registry = _solver_registry_per_backend[backend] - except KeyError: - available = ", ".join(_solver_registry_per_backend.keys()) - pytest.fail(f"Unknown solver backend: {backend}. Available: {available}") + if backend is None: + _solver_registry_to_use = nmo.solvers.solver_registry.copy() + else: + try: + _solver_registry_to_use = _solver_registry_per_backend[backend] + except KeyError: + available = ", ".join(_solver_registry_per_backend.keys()) + pytest.fail(f"Unknown solver backend: {backend}. Available: {available}") + + algo_name = os.getenv("NEMOS_OVERRIDE_ALGO") + impl_name = os.getenv("NEMOS_OVERRIDE_IMPL") + if algo_name and impl_name: + _solver_registry_to_use[algo_name] = getattr(nmo.solvers, impl_name) + elif algo_name or impl_name: + raise ValueError( + "Either both NEMOS_OVERRIDE_ALGO and NEMOS_OVERRIDE_IMPL have to be set or neither." 
+ ) # save the original registry so that we can restore it after original = nmo.solvers.solver_registry.copy() nmo.solvers.solver_registry.clear() - nmo.solvers.solver_registry.update(_backend_solver_registry) + nmo.solvers.solver_registry.update(_solver_registry_to_use) try: yield From 1f40c250c609c733c2c0b2d130f51bda85e4d54e Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 15:44:31 +0100 Subject: [PATCH 23/34] Run subset of tests with OptimistixOptaxGradientDescent --- .github/workflows/ci.yml | 3 +++ tox.ini | 15 ++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8ec1a3b34..dfeb2dd54 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -104,6 +104,9 @@ jobs: - name: Run solver-dependent tests with Optimistix backend run: tox -e backend-optimistix + - name: Rerun tests using GradientDescent with the OptimistixOptaxGradientDescent implementation + run: tox -e backend-optimistix-optax-gd + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4.0.1 env: diff --git a/tox.ini b/tox.ini index 0cd2d7873..ea426ac07 100644 --- a/tox.ini +++ b/tox.ini @@ -50,7 +50,20 @@ commands = description = Run solver tests and doctests with the Optimistix backend setenv = NEMOS_SOLVER_BACKEND = optimistix -commands = {[testenv:backend-jaxopt]commands} +commands = + {[testenv:backend-jaxopt]commands} + +# can be used as a template to run tests overriding a single solver implementation +[testenv:backend-optimistix-optax-gd] +description = Run tests using GradientDescent with the OptimistixOptaxGradientDescent implementation +setenv = + NEMOS_SOLVER_BACKEND = optimistix + NEMOS_OVERRIDE_ALGO = GradientDescent + NEMOS_OVERRIDE_IMPL = OptimistixOptaxGradientDescent +commands = + pytest -n auto --doctest-modules src/nemos/solvers -k {env:NEMOS_OVERRIDE_IMPL} + pytest -n auto -m solver_related -k {env:NEMOS_OVERRIDE_ALGO} + [testenv:fix] From 991bd84060570c8ac61095635b5b60d437a82b84 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 15:52:39 +0100 Subject: [PATCH 24/34] Remove done TODO --- src/nemos/solvers/_optax_optimistix_solvers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nemos/solvers/_optax_optimistix_solvers.py b/src/nemos/solvers/_optax_optimistix_solvers.py index cf9168658..36967b0f0 100644 --- a/src/nemos/solvers/_optax_optimistix_solvers.py +++ b/src/nemos/solvers/_optax_optimistix_solvers.py @@ -65,7 +65,6 @@ def __init_subclass__(cls, **kwargs): cls.__doc__ = inspect.cleandoc(full_doc) -# TODO: Figure out how to test both OptimistixOptaxGradientDescent and OptimistixNAG def _make_rate_scaler( stepsize: float | None, linesearch_kwargs: dict[str, Any] | None, From 04c91ebb00dae86b3e799e481e1f6c796695d8d9 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 19:07:19 +0100 Subject: [PATCH 25/34] Test both GD implementations in the same tox env --- .github/workflows/ci.yml | 3 --- tests/conftest.py | 23 +++++++++++++++-------- tox.ini | 18 ++++-------------- 3 files changed, 19 insertions(+), 25 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dfeb2dd54..8ec1a3b34 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -104,9 +104,6 @@ jobs: - name: Run solver-dependent tests with Optimistix backend run: tox -e backend-optimistix - - name: Rerun tests using GradientDescent with the OptimistixOptaxGradientDescent implementation - run: tox -e backend-optimistix-optax-gd - - name: Upload 
coverage reports to Codecov uses: codecov/codecov-action@v4.0.1 env: diff --git a/tests/conftest.py b/tests/conftest.py index 90b125fa9..72f1438fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1138,7 +1138,7 @@ def population_negativeBinomialGLM_model_instantiation_pytree( @pytest.fixture(autouse=True, scope="session") -def configure_solver_backend(): +def configure_solver_backend(request): """ Patch the solver registry depending on ``NEMOS_SOLVER_BACKEND``. @@ -1156,14 +1156,15 @@ def configure_solver_backend(): available = ", ".join(_solver_registry_per_backend.keys()) pytest.fail(f"Unknown solver backend: {backend}. Available: {available}") - algo_name = os.getenv("NEMOS_OVERRIDE_ALGO") - impl_name = os.getenv("NEMOS_OVERRIDE_IMPL") - if algo_name and impl_name: + override_solver = request.config.getini("override_solver") + if override_solver: + try: + algo_name, impl_name = override_solver.split(":", 1) + except ValueError: + raise ValueError( + f"override_solver must be in format 'algo:implementation', got: {override_solver}" + ) _solver_registry_to_use[algo_name] = getattr(nmo.solvers, impl_name) - elif algo_name or impl_name: - raise ValueError( - "Either both NEMOS_OVERRIDE_ALGO and NEMOS_OVERRIDE_IMPL have to be set or neither." - ) # save the original registry so that we can restore it after original = nmo.solvers.solver_registry.copy() @@ -1175,3 +1176,9 @@ def configure_solver_backend(): finally: nmo.solvers.solver_registry.clear() nmo.solvers.solver_registry.update(original) + + +def pytest_addoption(parser): + """Register custom ini options.""" + parser.addini("solver_backend", "Solver backend to use") + parser.addini("override_solver", "Override solver as 'algorithm:implementation'") diff --git a/tox.ini b/tox.ini index ea426ac07..e871f17de 100644 --- a/tox.ini +++ b/tox.ini @@ -46,25 +46,15 @@ commands = pytest -n auto --doctest-modules src/nemos/solvers pytest -n auto -m solver_related +# {[testenv:backend-jaxopt]commands} [testenv:backend-optimistix] description = Run solver tests and doctests with the Optimistix backend +allowlist_externals = env setenv = NEMOS_SOLVER_BACKEND = optimistix commands = - {[testenv:backend-jaxopt]commands} - -# can be used as a template to run tests overriding a single solver implementation -[testenv:backend-optimistix-optax-gd] -description = Run tests using GradientDescent with the OptimistixOptaxGradientDescent implementation -setenv = - NEMOS_SOLVER_BACKEND = optimistix - NEMOS_OVERRIDE_ALGO = GradientDescent - NEMOS_OVERRIDE_IMPL = OptimistixOptaxGradientDescent -commands = - pytest -n auto --doctest-modules src/nemos/solvers -k {env:NEMOS_OVERRIDE_IMPL} - pytest -n auto -m solver_related -k {env:NEMOS_OVERRIDE_ALGO} - - + pytest -n auto --doctest-modules src/nemos/solvers -k OptimistixOptaxGradientDescent -o override_solver=GradientDescent:OptimistixOptaxGradientDescent + pytest -n auto -m solver_related -k GradientDescent -o override_solver=GradientDescent:OptimistixOptaxGradientDescent [testenv:fix] commands= From f59d35f71c89f718261db90a20002c8e38a56b79 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 19:12:48 +0100 Subject: [PATCH 26/34] Fix and finish previous commit --- tests/conftest.py | 9 ++++++--- tox.ini | 5 +++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 72f1438fd..12654f616 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1140,10 +1140,13 @@ def population_negativeBinomialGLM_model_instantiation_pytree( 
@pytest.fixture(autouse=True, scope="session") def configure_solver_backend(request): """ - Patch the solver registry depending on ``NEMOS_SOLVER_BACKEND``. + Patch the solver registry depending on `NEMOS_SOLVER_BACKEND` and `override_solver`. - Used for running solver-dependent tests in separate tox environments - for the JAXopt and the Optimistix backends. + The `NEMOS_SOLVER_BACKEND` env variable is used for running solver-dependent tests + in separate tox environments for the JAXopt and the Optimistix backends. + + The `override_solver` pytest option is used to set a given solver algorithm's + implementation to a class available in nemos.solvers. """ backend = os.getenv("NEMOS_SOLVER_BACKEND") diff --git a/tox.ini b/tox.ini index e871f17de..6b9d900ab 100644 --- a/tox.ini +++ b/tox.ini @@ -46,13 +46,14 @@ commands = pytest -n auto --doctest-modules src/nemos/solvers pytest -n auto -m solver_related -# {[testenv:backend-jaxopt]commands} +# -k filters tests based on keyword [testenv:backend-optimistix] -description = Run solver tests and doctests with the Optimistix backend +description = Run solver tests and doctests with the Optimistix backend using both GradientDescent implementations allowlist_externals = env setenv = NEMOS_SOLVER_BACKEND = optimistix commands = + {[testenv:backend-jaxopt]commands} pytest -n auto --doctest-modules src/nemos/solvers -k OptimistixOptaxGradientDescent -o override_solver=GradientDescent:OptimistixOptaxGradientDescent pytest -n auto -m solver_related -k GradientDescent -o override_solver=GradientDescent:OptimistixOptaxGradientDescent From 0f76e18089772c43c2c37c53047861a05ce17515 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Thu, 6 Nov 2025 19:13:21 +0100 Subject: [PATCH 27/34] Another fix: remove unnecessary line --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index 6b9d900ab..0985a52c8 100644 --- a/tox.ini +++ b/tox.ini @@ -49,7 +49,6 @@ commands = # -k filters tests based on keyword [testenv:backend-optimistix] description = Run solver tests and doctests with the Optimistix backend using both GradientDescent implementations -allowlist_externals = env setenv = NEMOS_SOLVER_BACKEND = optimistix commands = From 8224834dbbea75461e1eaa2e817ac5ddf6b5e32a Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 7 Nov 2025 15:09:33 +0100 Subject: [PATCH 28/34] Pass solver arguments whose name is also in OptimistixConfig --- src/nemos/solvers/_optimistix_solvers.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/nemos/solvers/_optimistix_solvers.py b/src/nemos/solvers/_optimistix_solvers.py index 62865a13a..e5c3d3416 100644 --- a/src/nemos/solvers/_optimistix_solvers.py +++ b/src/nemos/solvers/_optimistix_solvers.py @@ -1,4 +1,5 @@ import dataclasses +import inspect from typing import Any, Callable, ClassVar, Type, TypeAlias import equinox as eqx @@ -110,11 +111,17 @@ def __init__( # take out the arguments that go into minimise, init, terminate and so on # and only pass the actually needed things to __init__ + solver_init_param_names = set( + inspect.getfullargspec(self._solver_cls.__init__).args + ) user_args = {} for f in dataclasses.fields(OptimistixConfig): kw = f.name if kw in solver_init_kwargs: - user_args[kw] = solver_init_kwargs.pop(kw) + if kw in solver_init_param_names: + user_args[kw] = solver_init_kwargs[kw] + else: + user_args[kw] = solver_init_kwargs.pop(kw) self.config = OptimistixConfig(maxiter=maxiter, **user_args) self._solver = self._solver_cls( From 
242ce1846f7044c4b890ae6008ecbe88d4640449 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 7 Nov 2025 15:10:00 +0100 Subject: [PATCH 29/34] First solution for the kind argument to the linesearch's while loop --- src/nemos/solvers/_fista.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/nemos/solvers/_fista.py b/src/nemos/solvers/_fista.py index cebc49834..ef1c48d62 100644 --- a/src/nemos/solvers/_fista.py +++ b/src/nemos/solvers/_fista.py @@ -88,6 +88,8 @@ class FISTA(optx.AbstractMinimiser[Y, Aux, ProxGradState]): acceleration: bool = True + adjoint: optx.AbstractAdjoint | None = None + def init( self, fn: Callable, @@ -297,18 +299,25 @@ def body_fun(carry): init_x = self.prox(init_x, self.regularizer_strength, stepsize) init_val = (init_x, stepsize) - # TODO: make kind dependent on the adjoint used? - # "lax" for implicit, "checkpointed" for RecursiveCheckpointAdjoint - # could make a partial function where these are available - # or just accept it in __init__? return eqx.internal.while_loop( cond_fun=cond_fun, body_fun=body_fun, init_val=init_val, max_steps=self.maxls, - kind="lax", + kind=self.while_loop_kind, ) + @property + def while_loop_kind(self): + """Determine `kind` argument to the linesearch's while_loop.""" + kind = "bounded" + if isinstance(self.adjoint, optx.RecursiveCheckpointAdjoint): + kind = "checkpointed" + if isinstance(self.adjoint, optx.ImplicitAdjoint): + kind = "lax" + + return kind + def terminate( self, fn: Callable, From a326540cf692339449ea611bf2722e3ff5605325 Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 7 Nov 2025 15:37:00 +0100 Subject: [PATCH 30/34] Revert "First solution for the kind argument to the linesearch's while loop" This reverts commit 242ce1846f7044c4b890ae6008ecbe88d4640449. --- src/nemos/solvers/_fista.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/nemos/solvers/_fista.py b/src/nemos/solvers/_fista.py index ef1c48d62..cebc49834 100644 --- a/src/nemos/solvers/_fista.py +++ b/src/nemos/solvers/_fista.py @@ -88,8 +88,6 @@ class FISTA(optx.AbstractMinimiser[Y, Aux, ProxGradState]): acceleration: bool = True - adjoint: optx.AbstractAdjoint | None = None - def init( self, fn: Callable, @@ -299,25 +297,18 @@ def body_fun(carry): init_x = self.prox(init_x, self.regularizer_strength, stepsize) init_val = (init_x, stepsize) + # TODO: make kind dependent on the adjoint used? + # "lax" for implicit, "checkpointed" for RecursiveCheckpointAdjoint + # could make a partial function where these are available + # or just accept it in __init__? return eqx.internal.while_loop( cond_fun=cond_fun, body_fun=body_fun, init_val=init_val, max_steps=self.maxls, - kind=self.while_loop_kind, + kind="lax", ) - @property - def while_loop_kind(self): - """Determine `kind` argument to the linesearch's while_loop.""" - kind = "bounded" - if isinstance(self.adjoint, optx.RecursiveCheckpointAdjoint): - kind = "checkpointed" - if isinstance(self.adjoint, optx.ImplicitAdjoint): - kind = "lax" - - return kind - def terminate( self, fn: Callable, From 2206d34494d42e74ef28d49b8bb58d87e4cd234e Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Fri, 7 Nov 2025 15:37:12 +0100 Subject: [PATCH 31/34] Revert "Pass solver arguments whose name is also in OptimistixConfig" This reverts commit 8224834dbbea75461e1eaa2e817ac5ddf6b5e32a. 
--- src/nemos/solvers/_optimistix_solvers.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/nemos/solvers/_optimistix_solvers.py b/src/nemos/solvers/_optimistix_solvers.py index e5c3d3416..62865a13a 100644 --- a/src/nemos/solvers/_optimistix_solvers.py +++ b/src/nemos/solvers/_optimistix_solvers.py @@ -1,5 +1,4 @@ import dataclasses -import inspect from typing import Any, Callable, ClassVar, Type, TypeAlias import equinox as eqx @@ -111,17 +110,11 @@ def __init__( # take out the arguments that go into minimise, init, terminate and so on # and only pass the actually needed things to __init__ - solver_init_param_names = set( - inspect.getfullargspec(self._solver_cls.__init__).args - ) user_args = {} for f in dataclasses.fields(OptimistixConfig): kw = f.name if kw in solver_init_kwargs: - if kw in solver_init_param_names: - user_args[kw] = solver_init_kwargs[kw] - else: - user_args[kw] = solver_init_kwargs.pop(kw) + user_args[kw] = solver_init_kwargs.pop(kw) self.config = OptimistixConfig(maxiter=maxiter, **user_args) self._solver = self._solver_cls( From deda0d8f421c32530616ac5d868c71b9a56241ca Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Tue, 11 Nov 2025 14:51:02 +0100 Subject: [PATCH 32/34] Another solution for deriving "kind" for the while loop --- src/nemos/solvers/_fista.py | 27 ++++++++++++++++++------ src/nemos/solvers/_optimistix_solvers.py | 22 ++++++++++++++++--- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/src/nemos/solvers/_fista.py b/src/nemos/solvers/_fista.py index cebc49834..2a8ac47e4 100644 --- a/src/nemos/solvers/_fista.py +++ b/src/nemos/solvers/_fista.py @@ -1,6 +1,6 @@ """Implementation of the FISTA algorithm as an Optimistix IterativeSolver. Adapted from JAXopt.""" -from typing import Any, Callable, ClassVar +from typing import Any, Callable, ClassVar, Literal import equinox as eqx import jax @@ -10,7 +10,7 @@ from optimistix._custom_types import Aux, Y from ..tree_utils import tree_add_scalar_mul, tree_sub -from ._optimistix_solvers import OptimistixAdapter +from ._optimistix_solvers import OptimistixAdapter, OptimistixConfig def prox_none(x: PyTree, hyperparams=None, scaling: float = 1.0): @@ -88,6 +88,8 @@ class FISTA(optx.AbstractMinimiser[Y, Aux, ProxGradState]): acceleration: bool = True + while_loop_kind: Literal["lax", "checkpointed", "bounded"] | None = None + def init( self, fn: Callable, @@ -297,16 +299,12 @@ def body_fun(carry): init_x = self.prox(init_x, self.regularizer_strength, stepsize) init_val = (init_x, stepsize) - # TODO: make kind dependent on the adjoint used? - # "lax" for implicit, "checkpointed" for RecursiveCheckpointAdjoint - # could make a partial function where these are available - # or just accept it in __init__? 
return eqx.internal.while_loop( cond_fun=cond_fun, body_fun=body_fun, init_val=init_val, max_steps=self.maxls, - kind="lax", + kind=self.while_loop_kind, ) def terminate( @@ -350,9 +348,24 @@ class OptimistixFISTA(OptimistixAdapter): _solver_cls = FISTA _proximal = True + def _params_derived_from_config(self) -> dict: + """Derive the "kind" parameter of the linesearch's while_loop based on the adjoint.""" + if isinstance(self.config.adjoint, optx.ImplicitAdjoint): + kind = "lax" + elif isinstance(self.config.adjoint, optx.RecursiveCheckpointAdjoint): + kind = "checkpointed" + else: + raise ValueError( + "adjoint has to be ImplicitAdjoint or RecursiveCheckpointAdjoint" + ) + + return {"while_loop_kind": kind} + class OptimistixNAG(OptimistixAdapter): """Port of Nesterov's accelerated gradient descent from JAXopt to the Optimistix API.""" _solver_cls = GradientDescent _proximal = False + + _params_derived_from_config = OptimistixFISTA._params_derived_from_config diff --git a/src/nemos/solvers/_optimistix_solvers.py b/src/nemos/solvers/_optimistix_solvers.py index 62865a13a..3c992d1cb 100644 --- a/src/nemos/solvers/_optimistix_solvers.py +++ b/src/nemos/solvers/_optimistix_solvers.py @@ -122,6 +122,7 @@ def __init__( rtol=rtol, norm=self.config.norm, **solver_init_kwargs, + **self._params_derived_from_config(), ) self.stats = {} @@ -209,9 +210,7 @@ def maxiter(self) -> int: def get_optim_info(self, state: OptimistixSolverState) -> OptimizationInfo: num_steps = self.stats["num_steps"].item() - function_val = ( - state.f.item() if hasattr(state, "f") else state.f_info.f.item() - ) # pyright: ignore + function_val = state.f.item() if hasattr(state, "f") else state.f_info.f.item() # pyright: ignore return OptimizationInfo( function_val=function_val, @@ -220,6 +219,23 @@ def get_optim_info(self, state: OptimistixSolverState) -> OptimizationInfo: reached_max_steps=(num_steps == self.maxiter), ) + def _params_derived_from_config(self) -> dict: + """ + Optionally derive some parameters for instantiating the wrapped solver. + + Parameters + ---------- + config: + OptimistixConfig created in the constructor. + + Returns + ------- + dict with argument names of _solver_cls.__init__ as keys and + their corresponding values as values. + Default implementation returns an empty dict. 
+        """
+        return {}
+
 
 class OptimistixBFGS(OptimistixAdapter):
     """Adapter for optimistix.BFGS."""

From 9686476e4f7d4724ad2b319f3fa5501a61230fe7 Mon Sep 17 00:00:00 2001
From: Bence Bagi
Date: Tue, 11 Nov 2025 15:01:43 +0100
Subject: [PATCH 33/34] Modify solver_init_kwargs instead of a separate dict for derived params

---
 src/nemos/solvers/_fista.py              |  8 +++++---
 src/nemos/solvers/_optimistix_solvers.py | 18 +++++++++++-------
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/src/nemos/solvers/_fista.py b/src/nemos/solvers/_fista.py
index 2a8ac47e4..ef8e3fac0 100644
--- a/src/nemos/solvers/_fista.py
+++ b/src/nemos/solvers/_fista.py
@@ -348,7 +348,9 @@ class OptimistixFISTA(OptimistixAdapter):
     _solver_cls = FISTA
     _proximal = True
 
-    def _params_derived_from_config(self) -> dict:
+    def adjust_solver_init_kwargs(
+        self, solver_init_kwargs: dict[str, Any]
+    ) -> dict[str, Any]:
         """Derive the "kind" parameter of the linesearch's while_loop based on the adjoint."""
         if isinstance(self.config.adjoint, optx.ImplicitAdjoint):
             kind = "lax"
@@ -359,7 +361,7 @@ def _params_derived_from_config(self) -> dict:
                 "adjoint has to be ImplicitAdjoint or RecursiveCheckpointAdjoint"
             )
 
-        return {"while_loop_kind": kind}
+        return {"while_loop_kind": kind, **solver_init_kwargs}
 
 
 class OptimistixNAG(OptimistixAdapter):
@@ -368,4 +370,4 @@ class OptimistixNAG(OptimistixAdapter):
     _solver_cls = GradientDescent
     _proximal = False
 
-    _params_derived_from_config = OptimistixFISTA._params_derived_from_config
+    adjust_solver_init_kwargs = OptimistixFISTA.adjust_solver_init_kwargs
diff --git a/src/nemos/solvers/_optimistix_solvers.py b/src/nemos/solvers/_optimistix_solvers.py
index 3c992d1cb..89a8fb71f 100644
--- a/src/nemos/solvers/_optimistix_solvers.py
+++ b/src/nemos/solvers/_optimistix_solvers.py
@@ -117,12 +117,14 @@ def __init__(
             user_args[kw] = solver_init_kwargs.pop(kw)
         self.config = OptimistixConfig(maxiter=maxiter, **user_args)
 
+        # make custom adjustments such as adding a derived "while_loop_kind" parameter for FISTA
+        solver_init_kwargs = self.adjust_solver_init_kwargs(solver_init_kwargs)
+
         self._solver = self._solver_cls(
             atol=tol,
             rtol=rtol,
             norm=self.config.norm,
             **solver_init_kwargs,
-            **self._params_derived_from_config(),
         )
 
         self.stats = {}
@@ -219,22 +221,24 @@ def get_optim_info(self, state: OptimistixSolverState) -> OptimizationInfo:
             reached_max_steps=(num_steps == self.maxiter),
         )
 
-    def _params_derived_from_config(self) -> dict:
+    def adjust_solver_init_kwargs(
+        self, solver_init_kwargs: dict[str, Any]
+    ) -> dict[str, Any]:
         """
-        Optionally derive some parameters for instantiating the wrapped solver.
+        Optionally adjust the parameters (e.g. derive from self.config) for instantiating the wrapped solver.
 
         Parameters
         ----------
-        config:
-            OptimistixConfig created in the constructor.
+        solver_init_kwargs:
+            Original keyword arguments that would be passed to _solver_cls.__init__.
 
         Returns
         -------
         dict with argument names of _solver_cls.__init__ as keys and
         their corresponding values as values.
-        Default implementation returns an empty dict.
+        Default implementation just returns solver_init_kwargs.
""" - return {} + return solver_init_kwargs class OptimistixBFGS(OptimistixAdapter): From 7c06e7b5bd5d0349542d5e2e980fa7db3075c40c Mon Sep 17 00:00:00 2001 From: Bence Bagi Date: Mon, 24 Nov 2025 13:18:09 +0100 Subject: [PATCH 34/34] Add tests and remove unused import --- src/nemos/solvers/_fista.py | 2 +- tests/test_fista_adjoint.py | 156 ++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 tests/test_fista_adjoint.py diff --git a/src/nemos/solvers/_fista.py b/src/nemos/solvers/_fista.py index ef8e3fac0..ac936bd3c 100644 --- a/src/nemos/solvers/_fista.py +++ b/src/nemos/solvers/_fista.py @@ -10,7 +10,7 @@ from optimistix._custom_types import Aux, Y from ..tree_utils import tree_add_scalar_mul, tree_sub -from ._optimistix_solvers import OptimistixAdapter, OptimistixConfig +from ._optimistix_solvers import OptimistixAdapter def prox_none(x: PyTree, hyperparams=None, scaling: float = 1.0): diff --git a/tests/test_fista_adjoint.py b/tests/test_fista_adjoint.py new file mode 100644 index 000000000..c3b162f3d --- /dev/null +++ b/tests/test_fista_adjoint.py @@ -0,0 +1,156 @@ +import os + +import optimistix as optx +import pytest + +import nemos as nmo + +# everything is solver-related here +pytestmark = pytest.mark.solver_related + + +@pytest.fixture +def optimistix_solver_registry(monkeypatch): + """Point GLM solver registry at the Optimistix implementations for this module.""" + registry = nmo.solvers.solver_registry.copy() + optimistix_registry = registry | { + "GradientDescent": nmo.solvers.OptimistixNAG, + "ProximalGradient": nmo.solvers.OptimistixFISTA, + } + monkeypatch.setattr(nmo.solvers, "solver_registry", optimistix_registry) + return optimistix_registry + + +@pytest.mark.parametrize( + "adjoint", + [optx.ImplicitAdjoint(), optx.RecursiveCheckpointAdjoint()], +) +@pytest.mark.parametrize( + "solver_name", + ["GradientDescent", "ProximalGradient"], +) +@pytest.mark.skipif( + os.getenv("NEMOS_SOLVER_BACKEND") != "optimistix", + reason="Only run with the Optimistix backend", +) +def test_glm_passes_adjoint_to_optimistix_config( + optimistix_solver_registry, adjoint, solver_name +): + glm = nmo.glm.GLM( + regularizer="Ridge", + regularizer_strength=0.1, + solver_name=solver_name, + solver_kwargs={"adjoint": adjoint}, + ) + glm.instantiate_solver() + + solver_adapter = glm._solver + assert isinstance(solver_adapter.config.adjoint, type(adjoint)) + + # not true because GLM.instantiate_solver does a deepcopy + # assert solver_adapter.config.adjoint is adjoint + + +@pytest.mark.parametrize( + ("adjoint", "expected_kind"), + [ + (optx.ImplicitAdjoint(), "lax"), + (optx.RecursiveCheckpointAdjoint(), "checkpointed"), + ], +) +@pytest.mark.parametrize( + "solver_name", + ["GradientDescent", "ProximalGradient"], +) +@pytest.mark.skipif( + os.getenv("NEMOS_SOLVER_BACKEND") != "optimistix", + reason="Only run with the Optimistix backend", +) +def test_fista_while_loop_kind_matches_adjoint( + optimistix_solver_registry, adjoint, expected_kind, solver_name +): + glm = nmo.glm.GLM( + regularizer="Ridge", + regularizer_strength=0.1, + solver_name=solver_name, + solver_kwargs={"adjoint": adjoint}, + ) + glm.instantiate_solver() + + fista_solver = glm._solver._solver + assert fista_solver.while_loop_kind == expected_kind + + +@pytest.mark.parametrize( + ("adjoint", "while_loop_kind"), + [ + (optx.ImplicitAdjoint(), "bounded"), + (optx.RecursiveCheckpointAdjoint(), "lax"), + ], +) +@pytest.mark.parametrize( + "solver_name", + ["GradientDescent", 
"ProximalGradient"], +) +@pytest.mark.skipif( + os.getenv("NEMOS_SOLVER_BACKEND") != "optimistix", + reason="Only run with the Optimistix backend", +) +def test_fista_explicit_while_loop_kind_overrides_adjoint( + optimistix_solver_registry, adjoint, while_loop_kind, solver_name +): + glm = nmo.glm.GLM( + regularizer="Ridge", + regularizer_strength=0.1, + solver_name=solver_name, + solver_kwargs={"adjoint": adjoint, "while_loop_kind": while_loop_kind}, + ) + glm.instantiate_solver() + + fista_solver = glm._solver._solver + assert fista_solver.while_loop_kind == while_loop_kind + + +@pytest.mark.parametrize( + "model_instantiation_type", + [ + "poissonGLM_model_instantiation", + # "population_poissonGLM_model_instantiation", + ], +) +@pytest.mark.parametrize( + "adjoint", [optx.ImplicitAdjoint(), optx.RecursiveCheckpointAdjoint()] +) +# @pytest.mark.parametrize("while_loop_kind", ["lax", "checkpointed", "bounded"]) +@pytest.mark.parametrize("while_loop_kind", ["bounded", "lax", "checkpointed"]) +@pytest.mark.skipif( + os.getenv("NEMOS_SOLVER_BACKEND") != "optimistix", + reason="Only run with the Optimistix backend", +) +def test_fit_succeeds_with_mismatched_adjoint_and_while_loop_kind( + optimistix_solver_registry, + request, + model_instantiation_type, + adjoint, + while_loop_kind, +): + data = request.getfixturevalue(model_instantiation_type) + X, y = data[:2] + model = data[2] + + # explicitly pass a while_loop_kind that does not match the adjoint-derived default + model.set_params( + solver_name="ProximalGradient", + solver_kwargs={ + "adjoint": adjoint, + "while_loop_kind": while_loop_kind, + "maxiter": 5, + }, + ) + model.fit(X, y) + + solver_adapter = model._solver + assert isinstance(solver_adapter.config.adjoint, type(adjoint)) + assert solver_adapter._solver.while_loop_kind == while_loop_kind + + assert model._get_fit_state()["solver_state_"] is not None