Support objective functions with auxiliary variables #444
Conversation
Child classes whose objective function returns aux will have to set this to true.
There was a potential bug where at the end of `GLM.update`, `self.scale_` was estimated using the old params.
Update the test for it as well.
I would call it `self.aux_`.
good catch
I think that's ok.
BalzaniEdoardo left a comment
I left a few comments but all minor stuff
src/nemos/glm/glm.py (outdated)
```python
# the output of loss. I believe it's the output of
# solver.l2_optimality_error
self.solver_state_ = state
# TODO: Should this be part of fit-state, so called aux_?
```
Probably also part of the fit state, so that people running a loop externally can concatenate the aux across iterations? If there is some metric that users may want to track, it would make it easier.
Do you mean solver state? By fit-state I meant what is extracted by `BaseRegressor._get_fit_state`.
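As a concrete illustration of tracking aux from an externally driven loop, here is a minimal, self-contained sketch using plain JAX gradient descent rather than the nemos solver interface; the objective, data, and the `max_abs_resid` key are illustrative placeholders, not part of this PR.

```python
import jax
import jax.numpy as jnp

# Toy objective returning (scalar loss, aux dict of diagnostics).
def loss_with_aux(params, X, y):
    resid = y - X @ params
    return jnp.mean(resid**2), {"max_abs_resid": jnp.max(jnp.abs(resid))}

value_and_grad = jax.value_and_grad(loss_with_aux, has_aux=True)

X, y = jnp.ones((5, 2)), jnp.zeros(5)
params = jnp.array([0.5, -0.5])

aux_history = []
for _ in range(10):
    (loss, aux), grads = value_and_grad(params, X, y)
    params = params - 0.1 * grads
    aux_history.append(aux)

# Concatenate the tracked diagnostic across iterations.
metric_trace = jnp.stack([a["max_abs_resid"] for a in aux_history])
```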
```python
    f_struct is "the shape+dtype of the output of `fn`".
    aux_struct is the same for the returned aux.
    """
    y0 = jax.tree_util.tree_map(optx._misc.inexact_asarray, y0)
```
What does this `inexact_asarray` do? And why are we using a private function? Is there a public equivalent?
```python
y0 = jax.tree_util.tree_map(optx._misc.inexact_asarray, y0)
if not has_aux:
    fn = optx._misc.NoneAux(fn)  # pyright: ignore
fn = optx._misc.OutAsArray(fn)
```
These `_misc` functions are tiny wrappers. Since they're internal utilities, we should consider porting them directly instead of relying on `optx._misc` for maintainability.
Specifically:
- The `NoneAux` wrapper has a `__call__` that returns `(func(x), None)`.
- `inexact_asarray` is essentially `jnp.asarray(x)`, which converts numeric scalars to arrays at default float precision (per the JAX docs: "all numeric scalar types with a (potentially) inexact representation").

Since these are so small, it would be more maintainable to inline them rather than depend on private Optimistix APIs.
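As an illustration of how small such a port would be, here is a hedged sketch of a local `inexact_asarray` stand-in following the description above (plain `jnp.asarray` plus promotion to an inexact dtype); it is not the optimistix definition verbatim.

```python
import jax.numpy as jnp

def _inexact_asarray(x):
    """Convert `x` to a JAX array with an inexact (floating) dtype.

    Plain `jnp.asarray`, plus a promotion step so integer/bool inputs
    end up at the default float precision.
    """
    x = jnp.asarray(x)
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(jnp.result_type(float, x.dtype))
    return x
```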
Probably `OutAsArray` is also another small one.
I see your point. I'm not sure yet. They're all small, but they call a bunch of other ones that would have to be ported as well, and it adds up.
For now I removed the dependency and made my own little wrappers for the absolutely necessary ones.
I'll look into whether and why the others are needed.
```python
    self.fun = lambda params, args: loss_fn(params, *args)[0]
else:
    self.fun = lambda params, args: loss_fn(params, *args)
    self.fun_with_aux = lambda params, args: (loss_fn(params, *args), None)
```
You can directly use `optx._misc.NoneAux` or its port if we have it. This is the same call used inside the function that returns the struct of f and aux, so at this point we can have both matching.
Using the locally defined `_wrap_aux` now.
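For illustration, a sketch of what a locally defined `_wrap_aux` helper and the `fun` / `fun_with_aux` wiring could look like; `_make_funs` and the exact lambdas here are hypothetical, and the actual implementation in the PR may differ.

```python
def _wrap_aux(fn):
    """Give an aux-less function the (value, aux) return convention."""
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs), None
    return wrapped


# Hypothetical wiring mirroring the snippet above: both attributes exist
# regardless of whether the loss itself returns aux.
def _make_funs(loss_fn, has_aux):
    if has_aux:
        fun = lambda params, args: loss_fn(params, *args)[0]
        fun_with_aux = lambda params, args: loss_fn(params, *args)
    else:
        fun = lambda params, args: loss_fn(params, *args)
        fun_with_aux = lambda params, args: _wrap_aux(loss_fn)(params, *args)
    return fun, fun_with_aux
```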
```python
    prev_reference_point, *args
)
full_grad_at_reference_point=full_grad,
aux=new_aux,
```
Nice, so we have the aux as part of the state. Is that true for every solver already?
This is true for all jaxopt and Optimistix solvers we have.
I don't think so. Optax doesn't support auxiliary variables, and so `optimistix.OptaxMinimiser`, and with that `OptimistixOptaxGradientDescent` and `OptimistixOptaxLBFGS`, don't have it in the state.
```diff
  def run(self, init_params: Params, *args: Any) -> JaxoptStepResult:
-     return self._solver.run(init_params, *self.hyperparams_prox, *args)
+     params, state = self._solver.run(init_params, *self.hyperparams_prox, *args)
+     return (params, state, state.aux)
```
Just curious, is there a reason for returning aux explicitly in `solver.run` and `solver.update`? As opposed to returning just params and state (following the internal `_solver` API) and then setting `self.aux_ = opt_state.aux` in GLM `fit` and `update`.
That would work great if we knew that `opt_state.aux` exists for sure, but not every solver is guaranteed to store it in the state.
It's also consistent with how `step` works in Optimistix.
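To spell out the trade-off, here is a toy sketch (not nemos code, all names illustrative) of two solvers, only one of which keeps aux in its state, both exposing the same `(params, state, aux)` return so the calling code never has to know.

```python
from typing import Any, NamedTuple

class StateWithAux(NamedTuple):
    step: int
    aux: Any

class StateWithoutAux(NamedTuple):
    step: int

class SolverA:
    """Solver that happens to carry aux in its state."""
    def run(self, init_params, fun_with_aux):
        value, aux = fun_with_aux(init_params)
        state = StateWithAux(step=1, aux=aux)
        return init_params, state, state.aux

class SolverB:
    """Solver whose state has no `.aux` field; aux is still returned."""
    def run(self, init_params, fun_with_aux):
        value, aux = fun_with_aux(init_params)
        state = StateWithoutAux(step=1)
        return init_params, state, aux

# The calling code is identical for both solvers:
fun = lambda p: (p**2, {"p": p})
for solver in (SolverA(), SolverB()):
    params, state, aux = solver.run(3.0, fun)
```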
Problem stated in #375:
Optimization libraries support objective functions that return not just a scalar function value, but also auxiliary data. These are mostly used for logging, diagnostics, and debugging.
Currently, the objective of GLMs does not have any aux, and throughout the codebase (e.g. in the `AbstractSolver` interface) nemos assumes that objective functions return a scalar only.
As other models might make use of aux, this PR prepares the solver interface and current models to deal with that.
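For context, a minimal sketch of an aux-returning objective driven end to end through a solver, assuming optimistix's public API (`optx.minimise(..., has_aux=True)`, `optx.BFGS(rtol=..., atol=...)`, and the `aux` field on the returned solution); the objective and data are toy placeholders, not nemos code.

```python
import jax.numpy as jnp
import optimistix as optx

# Objective with aux: a scalar loss plus a dict of diagnostics.
def fn(y, args):
    X, target = args
    resid = target - X @ y
    loss = jnp.mean(resid**2)
    return loss, {"max_abs_resid": jnp.max(jnp.abs(resid))}

X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
target = jnp.array([1.0, 2.0, 3.0])
solver = optx.BFGS(rtol=1e-6, atol=1e-6)

sol = optx.minimise(
    fn, solver, jnp.array([0.5, -0.5]), args=(X, target), has_aux=True
)
print(sol.value, sol.aux)  # final params and the aux evaluated there
```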
Main changes:
- `BaseRegressor` has a class attribute called `has_aux` that is passed to the solver on instantiation. Models whose objective returns aux have to overwrite this.
- `has_aux` is now a required argument of `AbstractSolver.__init__`.

Remaining questions, tasks:
- Aux is currently stored as `self.aux`. Should this be named `self.aux_` instead? Or returned?
- `GLM.update` performed the update, then estimated the scale using the previous parameters. Was that intended? I changed it to use the new parameters.
- `has_aux` is the only class attribute of `BaseRegressor`. Is that okay or should it be stored somewhere else? In any case, I will add a note about it to the developer notes.

Fixes #375