
Conversation

@bagibence (Collaborator):

Problem stated in #375:

Optimization libraries support objective functions that return not just a scalar function value, but also auxiliary data. These are mostly used for logging, diagnostics, and debugging.
Currently, the objective of GLMs does not have any aux, and throughout the codebase (e.g. in the AbstractSolver interface) nemos assumes that objective functions return a scalar only.

As other models might make use of aux, this PR prepares the solver interface and current models to deal with that.

Main changes:

  • (Prox-)SVRG can now handle objectives with aux. Following JAXopt solvers, aux is saved in the solver state. This saved aux does not come from the last evaluation of the objective or its gradient -- which is done on a minibatch -- but is the result of evaluating the gradient on the full data at the last reference point.
  • BaseRegressor has a class attribute called has_aux that is passed to the solver on instantiation. Models whose objective returns aux have to override this (see the sketch after this list).
  • Accordingly, has_aux is now a required argument of AbstractSolver.__init__.
  • Adapt tests to these changes.
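
As a concrete illustration of the pattern above, here is a minimal sketch: the class names, objective signature, and aux contents are made up for illustration; only the has_aux flag and the (loss, aux) return convention come from this PR.

import jax.numpy as jnp

class ScalarObjectiveModel:
    # default behaviour: the objective returns only a scalar loss
    has_aux = False

    def _objective(self, params, X, y):
        return jnp.mean((X @ params - y) ** 2)

class AuxObjectiveModel:
    # a model whose objective also returns auxiliary data overrides the flag;
    # the flag is forwarded to the solver on instantiation
    has_aux = True

    def _objective(self, params, X, y):
        residuals = X @ params - y
        loss = jnp.mean(residuals ** 2)
        aux = {"max_abs_residual": jnp.max(jnp.abs(residuals))}
        return loss, aux

# example call: the second element is the aux pytree
loss, aux = AuxObjectiveModel()._objective(jnp.ones(2), jnp.eye(2), jnp.zeros(2))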

Remaining questions, tasks:

  • In order to avoid breaking existing code, GLM.update doesn't return aux but saves it in self.aux. Should this be named self.aux_ instead? Or returned?
  • GLM.update performed the update, then estimated the scale using the previous parameters. Was that intended? I changed it to use the new parameters.
  • has_aux is the only class attribute of BaseRegressor. Is that okay or should it be stored somewhere else? In any case, I will add a note about it to the developer notes.

Fixes #375

@bagibence marked this pull request as ready for review November 24, 2025 09:42
@BalzaniEdoardo (Collaborator):

Remaining questions, tasks:

* In order to avoid breaking existing code, GLM.update doesn't return aux but saves it in `self.aux`. Should this be named `self.aux_` instead? Or returned?

I would call it self.aux_

* `GLM.update` performed the update, then estimated the scale using the previous parameters. Was that intended? I changed it to use the new parameters.

good catch

* `has_aux` is the only class attribute of `BaseRegressor`. Is that okay or should it be stored somewhere else? In any case, I will add a note about it to the developer notes.

I think that's ok.

@BalzaniEdoardo (Collaborator) left a comment:

I left a few comments, but it's all minor stuff.

# the output of loss. I believe it's the output of
# solver.l2_optimality_error
self.solver_state_ = state
# TODO: Should this be part of fit-state, so called aux_?
Collaborator:

Probably also part of the fit state, so that people running a loop externally can concatenate the aux across iterations? If there is some metric that users may want to track, it would make that easier.

Collaborator (Author):

Do you mean the solver state? By fit-state I meant what is extracted by BaseRegressor._get_fit_state.
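
Side note on the use case raised above: wherever aux ends up being stored, per-iteration aux pytrees can be stacked with a tree map. A small self-contained sketch, with made-up aux field names:

import jax
import jax.numpy as jnp

# two fake per-iteration aux pytrees, standing in for whatever the objective returns
aux_history = [
    {"log_likelihood": jnp.asarray(-1.3), "max_grad": jnp.asarray(0.5)},
    {"log_likelihood": jnp.asarray(-1.1), "max_grad": jnp.asarray(0.3)},
]
# stack leaf-wise: each entry becomes an array with one value per iteration
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *aux_history)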

f_struct is "the shape+dtype of the output of `fn`".
aux_struct is the same for the returned aux.
"""
y0 = jax.tree_util.tree_map(optx._misc.inexact_asarray, y0)
Collaborator:

What does this inexact_asarray do? And why are we using a private function? Is there a public equivalent?

y0 = jax.tree_util.tree_map(optx._misc.inexact_asarray, y0)
if not has_aux:
    fn = optx._misc.NoneAux(fn)  # pyright: ignore
fn = optx._misc.OutAsArray(fn)
Collaborator:

These _misc functions are tiny wrappers. Since they're internal utilities, we should consider porting them directly instead of relying on optx._misc, for maintainability.

Specifically:

  • The wrapper module (NoneAux) has a __call__ that returns (func(x), None)
  • inexact_asarray is just jnp.asarray(x) which converts numeric scalars to arrays at default float precision (per JAX docs: "all numeric scalar types with a (potentially) inexact representation")

Since these are so small, it would be more maintainable to inline them rather than depend on private Optimistix APIs.

Collaborator:

Probably OutAsArray is another small one.

Collaborator (Author):

I see your point. I'm not sure yet. They're all small, but they call a bunch of other ones that would have to be ported as well, and it adds up.
For now I removed the dependency and made my own little wrappers for the absolutely necessary ones.
I'll look into whether and why the others are needed.
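
For reference, a minimal sketch of what such local wrappers might look like; this is an illustration of the approach based on a reading of the optimistix helpers, not the exact port in this PR.

import jax.numpy as jnp

def _wrap_aux(fn):
    # Give an aux-less objective the (value, aux) signature expected downstream,
    # in the spirit of optimistix's NoneAux (sketch, not the exact port).
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs), None
    return wrapped

def _inexact_asarray(x):
    # Convert x to a JAX array with an inexact (floating) dtype, roughly what
    # optx._misc.inexact_asarray does.
    x = jnp.asarray(x)
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(jnp.result_type(float))
    return x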

    self.fun = lambda params, args: loss_fn(params, *args)[0]
else:
    self.fun = lambda params, args: loss_fn(params, *args)
    self.fun_with_aux = lambda params, args: (loss_fn(params, *args), None)
Collaborator:

You can directly use optx._misc.NoneAux or its port if we have it. This is the same call used inside the function that returns the struct of f and aux, so at this point we can have both matching.

Collaborator (Author):

Using the locally defined _wrap_aux now.

prev_reference_point, *args
)
full_grad_at_reference_point=full_grad,
aux=new_aux,
Collaborator:

Nice, so we have the aux as part of the state. Is that true for every solver already?

Collaborator:

This is true for all jaxopt and optimistix solvers we have.

Collaborator (Author):

I don't think so. Optax doesn't support auxiliary variables, so optimistix.OptaxMinimiser, and with it OptimistixOptaxGradientDescent and OptimistixOptaxLBFGS, don't have it in the state.

def run(self, init_params: Params, *args: Any) -> JaxoptStepResult:
return self._solver.run(init_params, *self.hyperparams_prox, *args)
params, state = self._solver.run(init_params, *self.hyperparams_prox, *args)
return (params, state, state.aux)
Collaborator:

Just curious, is there a reason for returning aux explicitly in solver.run and solver.update, as opposed to returning just params and state (following the internal _solver API) and then setting self.aux_ = opt_state.aux in GLM fit and update?

Collaborator (Author):

That would work great if we knew that opt_state.aux exists for sure, but not every solver is guaranteed to store it in the state.
Returning aux explicitly is also consistent with how step works in Optimistix.
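
To illustrate the convention being defended here, a sketch with stand-in classes (not the actual nemos code): run returns (params, state, aux) explicitly, so the model can store aux without assuming the solver state exposes an aux field.

from types import SimpleNamespace

class DummySolver:
    # stand-in solver: returns params unchanged plus fake state and aux
    def run(self, init_params, *args):
        state = {"iterations": 10}
        aux = {"final_loss": 0.42}
        return init_params, state, aux

def fit(model, solver, init_params, *args):
    params, state, aux = solver.run(init_params, *args)
    model.coef_ = params
    model.solver_state_ = state
    model.aux_ = aux  # saved on the model instead of being returned
    return model

model = fit(SimpleNamespace(), DummySolver(), [0.0, 0.0])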
