34 commits
5926fac
Add FISTA adaptation and adapter as OptimistixFISTA
bagibence Oct 9, 2025
618e769
Use OptimistixFISTA when using the Optimistix backend in tests
bagibence Oct 9, 2025
139990d
Add OptimistixGradientDescent that uses the accelerated JAXopt port
bagibence Oct 9, 2025
bb94b25
Use two_norm by default in Cauchy criterion
bagibence Oct 9, 2025
ab09fe5
Remove prox from accepted arguments in Optimistix solvers
bagibence Oct 9, 2025
9d630e0
Add todo
bagibence Oct 9, 2025
49fd795
Add docstrings
bagibence Oct 10, 2025
85a26d0
OptimistixGradientDescent -> OptimistixNAG
bagibence Oct 10, 2025
12a6f21
Remove OptimistixOptaxProximalGradient
bagibence Oct 10, 2025
2266383
Use tree_sub from tree_utils
bagibence Oct 10, 2025
ebf8b00
Use tree_add_scalar_mul from tree_utils
bagibence Oct 10, 2025
10401cd
Switch to backtracking in OptimistixOptaxGradientDescent
bagibence Oct 10, 2025
909a3d0
Satisfy linter
bagibence Oct 10, 2025
b802782
Change default tolerance to 1e-4
bagibence Oct 10, 2025
5a3f771
add todo
bagibence Oct 22, 2025
9b504f0
Update developer notes
bagibence Oct 24, 2025
b90c0de
Rename _fista_port.py to _fista.py
bagibence Nov 6, 2025
c6271b3
Change module docstring
bagibence Nov 6, 2025
9507ca9
Typing
bagibence Nov 6, 2025
f352c0a
Extend docstring
bagibence Nov 6, 2025
1e722d0
Fix import after file rename
bagibence Nov 6, 2025
2fc6f02
Add env vars to override solver implementation in tests
bagibence Nov 6, 2025
1f40c25
Run subset of tests with OptimistixOptaxGradientDescent
bagibence Nov 6, 2025
991bd84
Remove done TODO
bagibence Nov 6, 2025
04c91eb
Test both GD implementations in the same tox env
bagibence Nov 6, 2025
f59d35f
Fix and finish previous commit
bagibence Nov 6, 2025
0f76e18
Another fix: remove unnecessary line
bagibence Nov 6, 2025
8224834
Pass solver arguments whose name is also in OptimistixConfig
bagibence Nov 7, 2025
242ce18
First solution for the kind argument to the linesearch's while loop
bagibence Nov 7, 2025
a326540
Revert "First solution for the kind argument to the linesearch's whil…
bagibence Nov 7, 2025
2206d34
Revert "Pass solver arguments whose name is also in OptimistixConfig"
bagibence Nov 7, 2025
deda0d8
Another solution for deriving "kind" for the while loop
bagibence Nov 11, 2025
9686476
Modify solver_init_kwargs instead of a separate dict for derived params
bagibence Nov 11, 2025
7c06e7b
Add tests and remove unused import
bagibence Nov 24, 2025
14 changes: 10 additions & 4 deletions docs/developers_notes/07-solvers.md
@@ -66,12 +66,13 @@ Abstract Class AbstractSolver
│ │ │
│ │ ├─ Concrete Subclass OptimistixBFGS
│ │ ├─ Concrete Subclass OptimistixLBFGS
│ │ ├─ Concrete Subclass OptimistixFISTA
│ │ ├─ Concrete Subclass OptimistixNAG
│ │ ├─ Concrete Subclass OptimistixNonlinearCG
│ │ └─ Abstract Subclass AbstractOptimistixOptaxSolver
│ │ │
│ │ ├─ Concrete Subclass OptimistixOptaxLBFGS
│ │ ├─ Concrete Subclass OptimistixOptaxGradientDescent
│ │ └─ Concrete Subclass OptimistixOptaxProximalGradient
│ │ └─ Concrete Subclass OptimistixOptaxGradientDescent
│ │
│ └─ Abstract Subclass JaxoptAdapter
│ │
@@ -86,8 +87,13 @@ Abstract Class AbstractSolver
```

`OptaxOptimistixSolver` is an adapter for Optax solvers, relying on `optimistix.OptaxMinimiser` to run the full optimization loop.
Optimistix does not have implementations of Nesterov acceleration, so gradient descent is implemented by wrapping `optax.sgd` which does support it.
(Although what Optax calls Nesterov acceleration is not the [original method developed for convex optimization](https://hengshuaiyao.github.io/papers/nesterov83.pdf) but the [version adapted for training deep networks with SGD](https://proceedings.mlr.press/v28/sutskever13.html). JAXopt did implement the original method, and [a port of this is planned to be added to NeMoS](https://github.com/flatironinstitute/nemos/issues/380).)

Gradient descent is implemented by two classes:
- One wraps `optax.sgd`, which supports momentum and acceleration (see the sketch below).
Note that what Optax calls Nesterov acceleration is not the [original method developed for convex optimization](https://hengshuaiyao.github.io/papers/nesterov83.pdf) but the [version adapted for training deep networks with SGD](https://proceedings.mlr.press/v28/sutskever13.html).
- Because JAXopt implemented the original method, a [port of JAXopt's `GradientDescent` was added to NeMoS](https://github.com/flatironinstitute/nemos/pull/411) as `OptimistixNAG`.
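
A minimal sketch of the Optax-based path, assuming a toy quadratic objective; the learning rate, momentum, and tolerances below are illustrative, not NeMoS defaults:

```python
# Illustrative only: drive optax.sgd (with Nesterov momentum) through
# optimistix.OptaxMinimiser, the pattern the Optax-backed adapters rely on.
import jax.numpy as jnp
import optax
import optimistix as optx


def loss(params, args):
    # Stand-in objective; in NeMoS this would be e.g. a GLM negative log-likelihood.
    return jnp.sum((params - 3.0) ** 2)


solver = optx.OptaxMinimiser(
    optax.sgd(learning_rate=1e-2, momentum=0.9, nesterov=True),
    rtol=1e-4,
    atol=1e-4,
)
sol = optx.minimise(loss, solver, jnp.zeros(5), args=None, max_steps=10_000)
# sol.value holds the minimiser (close to 3.0 here).
```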

Similarly to NAG, an accelerated proximal gradient algorithm ([FISTA](https://www.ceremade.dauphine.fr/~carlier/FISTA)) was [ported from JAXopt](https://github.com/flatironinstitute/nemos/pull/411) as `OptimistixFISTA`.
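
For reference, the generic FISTA iteration (Beck & Teboulle) looks like the sketch below on a lasso-type problem; this is a textbook illustration with made-up data, not the JAXopt/NeMoS port itself:

```python
# Textbook FISTA on min_x 0.5 * ||A x - b||^2 + lam * ||x||_1 (illustrative data).
import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(0), (50, 20))
b = jax.random.normal(jax.random.PRNGKey(1), (50,))
lam = 0.1


def f(x):  # smooth part of the objective
    return 0.5 * jnp.sum((A @ x - b) ** 2)


def prox_l1(v, step):  # prox of step * lam * ||.||_1, i.e. soft-thresholding
    return jnp.sign(v) * jnp.maximum(jnp.abs(v) - step * lam, 0.0)


grad_f = jax.grad(f)
step = 1.0 / jnp.linalg.norm(A, ord=2) ** 2  # 1 / Lipschitz constant of grad f

x_prev = jnp.zeros(20)
y, t = x_prev, 1.0
for _ in range(200):
    x = prox_l1(y - step * grad_f(y), step)            # proximal gradient step at y
    t_next = (1.0 + jnp.sqrt(1.0 + 4.0 * t**2)) / 2.0
    y = x + ((t - 1.0) / t_next) * (x - x_prev)        # Nesterov extrapolation
    x_prev, t = x, t_next
```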

Available solvers and which implementation they dispatch to are defined in the solver registry.
A list of available solvers is provided by {py:func}`nemos.solvers.list_available_solvers`, and extended documentation about each solver can be accessed using {py:func}`nemos.solvers.get_solver_documentation`.
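
A quick sketch of querying the registry with the functions mentioned above, assuming `get_solver_documentation` accepts a solver name string; the exact names returned depend on the installed NeMoS version:

```python
import nemos

# Names of all solvers currently registered in NeMoS.
print(nemos.solvers.list_available_solvers())

# Assumed usage: fetch the extended documentation for one solver by name.
print(nemos.solvers.get_solver_documentation("GradientDescent"))
```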
1 change: 1 addition & 0 deletions scripts/check_parameter_naming.py
@@ -83,6 +83,7 @@
{"solver_kwargs", "solver_init_kwargs"},
{"unaccepted_name", "accepted_name"},
{"fn", "fun"},
{"ax", "aux"},
]


2 changes: 1 addition & 1 deletion src/nemos/solvers/__init__.py
@@ -1,5 +1,6 @@
"""Custom solvers module."""

from ._fista import OptimistixFISTA, OptimistixNAG
from ._jaxopt_solvers import (
JaxoptBFGS,
JaxoptGradientDescent,
@@ -10,7 +11,6 @@
from ._optax_optimistix_solvers import (
OptimistixOptaxGradientDescent,
OptimistixOptaxLBFGS,
OptimistixOptaxProximalGradient,
)
from ._optimistix_solvers import OptimistixBFGS, OptimistixNonlinearCG
from ._solver_doc_helper import get_solver_documentation