Support for optax lbfgs and related optimizers with NNX #4144
An update: I have a partial hack to https://github.com/google/flax/blob/main/flax/nnx/nnx/training/optimizer.py. If I replace the `update` method with

```python
def update(self, grads, value=None, value_fn=None):
    gdef, state = nnx.split(self.model, self.wrt)

    def value_fn_wrapped(state):
        model = nnx.merge(gdef, state)
        return value_fn(model)

    updates, new_opt_state = self.tx.update(
        grads, self.opt_state, state,
        grad=grads, value=value, value_fn=value_fn_wrapped,
    )
    new_params = optax.apply_updates(state, updates)
    assert isinstance(new_params, nnx.State)
    self.step.value += 1
    nnx.update(self.model, new_params)
    self.opt_state = new_opt_state
```

then it seems to work. The key is that the `value_fn` has to be wrapped so that it accepts the split `state`, merges it back into a model with `nnx.merge`, and evaluates the loss on that, since optax only ever sees the state pytree.

The full version of this (where I just call my replacement `update` as a free function, passing the optimizer explicitly) is:

```python
# Takes the baseline version and uses vmap, adds in a learning rate scheduler
import jax
import jax.numpy as jnp
from jax import random
import optax
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx

N = 500  # samples
M = 2
sigma = 0.001
rngs = nnx.Rngs(42)
theta = random.normal(rngs(), (M,))
X = random.normal(rngs(), (N, M))
Y = X @ theta + sigma * random.normal(rngs(), (N,))  # adding noise

def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals_loss(model, X, Y):
    return jnp.mean(jax.vmap(residual, in_axes=(None, 0, 0))(model, X, Y))

model = nnx.Linear(M, 1, use_bias=False, rngs=rngs)

# From https://github.com/google/flax/blob/main/flax/nnx/nnx/training/optimizer.py
def update(self, grads, value=None, value_fn=None):
    gdef, state = nnx.split(self.model, self.wrt)

    def value_fn_wrapped(state):
        model = nnx.merge(gdef, state)
        return value_fn(model)

    updates, new_opt_state = self.tx.update(
        grads, self.opt_state, state,
        grad=grads, value=value, value_fn=value_fn_wrapped,
    )
    new_params = optax.apply_updates(state, updates)
    assert isinstance(new_params, nnx.State)
    self.step.value += 1
    nnx.update(self.model, new_params)
    self.opt_state = new_opt_state

lr = 0.001
optimizer = nnx.Optimizer(
    model,
    optax.lbfgs(),
    # optax.sgd(lr),
)

@nnx.jit
def train_step(model, optimizer, X, Y):
    def loss_fn(model):
        return residuals_loss(model, X, Y)

    loss, grads = nnx.value_and_grad(loss_fn, has_aux=False)(model)
    # optimizer.update(grads)
    update(optimizer, grads, value=loss, value_fn=loss_fn)
    return loss

num_epochs = 20
batch_size = 1024
dataset = jdl.ArrayDataset(X, Y)
train_loader = DataLoaderJAX(dataset, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        loss = train_step(model, optimizer, X_batch, Y_batch)
    if epoch % 2 == 0:
        print(
            f"Epoch {epoch}, ||theta - theta_hat|| = "
            f"{jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}"
        )
print(
    f"||theta - theta_hat|| = "
    f"{jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}"
)
```

This is using full-batch (the batch size exceeds N), which is appropriate for lbfgs unless the learning rate is decreased.
Just checking in on this issue. I noticed the underlying https://github.com/google/flax/blob/main/flax/nnx/training/optimizer.py code has changed since I posted this workaround, so I am not sure if it is still a good idea or if there is a better way to reimplement `update`.
@jlperla we changed how Optimizer stores the `opt_state` since this was posted. I looked at #4351 and at the `update` workaround above.

The choice of having `update` forward extra keyword arguments to the underlying optax transform seems like the right direction. As an aside, if we had infinite resources it would be better to simply have NNX-native optimizers. I recently created a toy version of one.
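For concreteness, a toy NNX-native SGD could look something like the sketch below (hypothetical; `ToySGD` is an invented name for illustration, not the actual toy referred to above):

```python
import jax
from flax import nnx

class ToySGD(nnx.Module):
    """Toy sketch of an NNX-native optimizer: it updates the model's Params
    in place rather than going through an optax GradientTransformation."""

    def __init__(self, lr: float):
        self.lr = lr

    def update(self, model: nnx.Module, grads: nnx.State):
        params = nnx.state(model, nnx.Param)
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - self.lr * g, params, grads
        )
        nnx.update(model, new_params)
```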
Yes, the key is really the `value_fn_wrapped` closure. I think I see what you mean about the kwargs for the update; I can modify my PR to see what you think. (For me at least, I am a little skeptical of writing many NNX-native optimizers. I really like the optax gradient-transformation infrastructure and its flexibility.)
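Concretely, a hypothetical sketch of the kwargs version (names and signature are just for discussion, not the actual flax API):

```python
# Hypothetical: forward arbitrary extra arguments to the optax transform, so
# lbfgs-style optimizers can receive value/grad/value_fn without a custom method.
def update(self, grads, **kwargs):
    gdef, state = nnx.split(self.model, self.wrt)
    # A caller-supplied value_fn would still need to be wrapped to accept the
    # split state, e.g. via nnx.merge as in the workaround above.
    updates, new_opt_state = self.tx.update(
        grads, self.opt_state, state, **kwargs
    )
    new_params = optax.apply_updates(state, updates)
    self.step.value += 1
    nnx.update(self.model, new_params)
    self.opt_state = new_opt_state
```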
I am trying to use L-BFGS and related optimizers with nnx + optax, but am running into trouble. It might be that optax has a slightly different optimization interface in those cases (https://optax.readthedocs.io/en/latest/api/optimizers.html#lbfgs), with a function called `optax.value_and_grad_from_state` and also a change to the `optimizer.update` interface. In particular, note that the sample code for these optimizers in the optax docs looks roughly like the following.
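(A sketch reconstructed from the optax L-BFGS documentation; the exact snippet there may differ slightly.)

```python
import jax.numpy as jnp
import optax

def f(x):
    return jnp.sum(x ** 2)  # toy objective

solver = optax.lbfgs()
params = jnp.array([1.0, 2.0, 3.0])
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(f)

for _ in range(5):
    # Reuses the value/grad cached by the linesearch when available.
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=f
    )
    params = optax.apply_updates(params, updates)
    print("Objective:", f(params))
```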
So maybe we need the ability to pass a `value_fn` argument on? Any easy fixes or things I am missing?

Problem you have encountered:

Take the following implementation of linear least squares (LLS) with NNX + optax, sketched below.
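(A minimal sketch of the baseline, assuming the same model and loss as the full script in the comment above, with plain SGD.)

```python
import jax.numpy as jnp
import optax
from flax import nnx

rngs = nnx.Rngs(42)
model = nnx.Linear(2, 1, use_bias=False, rngs=rngs)

def loss_fn(model, X, Y):
    Y_hat = model(X)  # nnx.Linear maps over the batch dimension
    return jnp.mean((jnp.squeeze(Y_hat) - Y) ** 2)

lr = 0.001
optimizer = nnx.Optimizer(model, optax.sgd(lr))

@nnx.jit
def train_step(model, optimizer, X, Y):
    loss, grads = nnx.value_and_grad(loss_fn)(model, X, Y)
    optimizer.update(grads)  # fine for sgd; breaks for lbfgs
    return loss
```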
But if I change the optimizer from `optax.sgd(lr)` to `optax.lbfgs()`, then I would expect NNX to work.

Logs, error messages, etc:

The error it gives is

Steps to reproduce:

Run the example above with the optimizer swapped out, i.e.
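```python
optimizer = nnx.Optimizer(
    model,
    optax.lbfgs(),  # instead of optax.sgd(lr)
)
```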