Support for optax lbfgs and related optimizers with NNX #4144

Closed
jlperla opened this issue Aug 26, 2024 · 4 comments · Fixed by #4351
Labels
Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required)

Comments

@jlperla
Contributor

jlperla commented Aug 26, 2024

I am trying to use L-BFGS and related optimizers with nnx + optax, but am running into trouble. It might be that optax has a slightly different optimization interface in those cases: https://optax.readthedocs.io/en/latest/api/optimizers.html#lbfgs introduces a function called optax.value_and_grad_from_state and also changes the optimizer.update interface.

In particular, note that the sample code for these optimizers in https://optax.readthedocs.io/en/latest/api/optimizers.html#lbfgs looks like

opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(f)
for _ in range(5):
  value, grad = value_and_grad(params, state=opt_state)
  updates, opt_state = solver.update(
     grad, opt_state, params, value=value, grad=grad, value_fn=f
  )
  params = optax.apply_updates(params, updates)

So maybe we need the ability to pass a value_fn argument through to the optimizer? Are there any easy fixes or things I am missing?
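
For concreteness, what I would like to be able to write from the NNX side is something like the following (a hypothetical signature illustrating the request, reusing the model and a loss_fn closure like residuals_loss from the repro below; this API does not currently exist):

loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads, value=loss, value_fn=loss_fn)  # extra linesearch arguments passed through to optax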

Problem you have encountered:

Take the following linear least squares (LLS) implementation with NNX + optax:

import jax
import jax.numpy as jnp
from jax import random
import optax
import equinox as eqx
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx

N = 64  # samples
M = 2
sigma = 0.001
rngs = nnx.Rngs(42)
theta = random.normal(rngs(), (M,))
X = random.normal(rngs(), (N, M))
Y = X @ theta + sigma * random.normal(rngs(), (N,))  # Adding noise

def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals_loss(model, X, Y):
    return jnp.mean(jax.vmap(residual, in_axes=(None, 0, 0))(model, X, Y))

model = nnx.Linear(M, 1, use_bias=False, rngs=rngs)

lr = 0.001
optimizer = nnx.Optimizer(model,
                          optax.sgd(lr) #optax.lbfgs()
                          )

@nnx.jit
def train_step(model, optimizer, X, Y):
    grad_fn = nnx.value_and_grad(residuals_loss, has_aux=False)
    loss, grads = grad_fn(model, X, Y)
    optimizer.update(grads)
    return loss

num_epochs = 500
batch_size = 64
dataset = jdl.ArrayDataset(X, Y)
train_loader = DataLoaderJAX(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        loss = train_step(model, optimizer, X_batch, Y_batch)

    if epoch % 100 == 0:
        print(
            f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}"
        )

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}")

But if I change the optimizer from optax.sgd(lr) to optax.lbfgs(), I would expect NNX to still work; instead, it fails.

Logs, error messages, etc:

The error it gives is

  File "/Users/jlperla/Documents/GitHub/ECON622_instructor/lectures/examples/linear_regression_jax_nnx.py", line 52, in <module>
    loss = train_step(model, optimizer, X_batch, Y_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/graph.py", line 1043, in update_context_manager_wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/transforms/transforms.py", line 359, in jit_wrapper
    out, output_state, output_graphdef = jitted_fn(
                                         ^^^^^^^^^^
  File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/transforms/transforms.py", line 158, in jit_fn
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/Users/jlperla/Documents/GitHub/ECON622_instructor/lectures/examples/linear_regression_jax_nnx.py", line 43, in train_step
    optimizer.update(grads)
  File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/flax/nnx/nnx/training/optimizer.py", line 201, in update
    updates, new_opt_state = self.tx.update(grads, self.opt_state, state)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jlperla/anaconda3/envs/econ622/lib/python3.11/site-packages/optax/transforms/_combining.py", line 73, in update_fn
    updates, new_s = fn(updates, s, params, **extra_args)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: scale_by_zoom_linesearch.<locals>.update_fn() missing 3 required keyword-only arguments: 'value', 'grad', and 'value_fn'

Steps to reproduce:

Run the example above with the optimizer swapped out, i.e.

import jax
import jax.numpy as jnp
from jax import random
import optax
import equinox as eqx
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx

N = 64  # samples
M = 2
sigma = 0.001
rngs = nnx.Rngs(42)
theta = random.normal(rngs(), (M,))
X = random.normal(rngs(), (N, M))
Y = X @ theta + sigma * random.normal(rngs(), (N,))  # Adding noise

def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals_loss(model, X, Y):
    return jnp.mean(jax.vmap(residual, in_axes=(None, 0, 0))(model, X, Y))

model = nnx.Linear(M, 1, use_bias=False, rngs=rngs)

lr = 0.001
optimizer = nnx.Optimizer(model,
                          optax.lbfgs()
                          )

@nnx.jit
def train_step(model, optimizer, X, Y):
    grad_fn = nnx.value_and_grad(residuals_loss, has_aux=False)
    loss, grads = grad_fn(model, X, Y)
    optimizer.update(grads)
    return loss

num_epochs = 500
batch_size = 64
dataset = jdl.ArrayDataset(X, Y)
train_loader = DataLoaderJAX(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        loss = train_step(model, optimizer, X_batch, Y_batch)

    if epoch % 100 == 0:
        print(
            f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}"
        )

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}")
@cgarciae cgarciae added the Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) label Aug 26, 2024
@jlperla
Contributor Author

jlperla commented Aug 27, 2024

An update: I have a partial hack for https://github.com/google/flax/blob/main/flax/nnx/nnx/training/optimizer.py

If I replace the .update with

def update(self, grads, value=None, value_fn=None):
    # Split the model into its graph definition and parameter state so that
    # optax only sees the state pytree.
    gdef, state = nnx.split(self.model, self.wrt)

    def value_fn_wrapped(state):
        # The linesearch evaluates value_fn on the candidate state, so merge
        # it back into a model before calling the user's loss function.
        model = nnx.merge(gdef, state)
        return value_fn(model)

    updates, new_opt_state = self.tx.update(
        grads, self.opt_state, state, grad=grads, value=value, value_fn=value_fn_wrapped
    )

    new_params = optax.apply_updates(state, updates)
    assert isinstance(new_params, nnx.State)

    self.step.value += 1
    nnx.update(self.model, new_params)
    self.opt_state = new_opt_state

Then it seems to work. The key is that the linesearch calls value_fn with the parameter state, which the user's loss function (written in terms of the model) cannot take directly, so you need to split and merge inside a wrapper. I don't know whether this is high performance or not, and it certainly isn't using the grad and value caching from optax.value_and_grad_from_state.

The full version of this (where I just call my replacement function with the optimizer) is:

# Takes the baseline version and uses vmap, adds in a learning rate scheduler
import jax
import jax.numpy as jnp
from jax import random
import optax
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx

N = 500  # samples
M = 2
sigma = 0.001
rngs = nnx.Rngs(42)
theta = random.normal(rngs(), (M,))
X = random.normal(rngs(), (N, M))
Y = X @ theta + sigma * random.normal(rngs(), (N,))  # Adding noise

def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals_loss(model, X, Y):
    return jnp.mean(jax.vmap(residual, in_axes=(None, 0, 0))(model, X, Y))

model = nnx.Linear(M, 1, use_bias=False, rngs=rngs)

# From https://github.com/google/flax/blob/main/flax/nnx/nnx/training/optimizer.py
def update(self, grads, value = None, value_fn = None):
    gdef, state = nnx.split(self.model, self.wrt)

    def value_fn_wrapped(state):
        model = nnx.merge(gdef, state)
        return value_fn(model)

    updates, new_opt_state = self.tx.update(grads, self.opt_state, state, grad = grads, value = value, value_fn = value_fn_wrapped)


    new_params = optax.apply_updates(state, updates)
    assert isinstance(new_params, nnx.State)

    self.step.value += 1
    nnx.update(self.model, new_params)
    self.opt_state = new_opt_state




lr = 0.001
optimizer = nnx.Optimizer(model,
                          optax.lbfgs(),
                           #optax.sgd(lr),
                          )

@nnx.jit
def train_step(model, optimizer, X, Y):
    def loss_fn(model):
        return residuals_loss(model, X, Y)
    loss, grads =  nnx.value_and_grad(loss_fn, has_aux=False)(model)
    # optimizer.update(grads)
    update(optimizer, grads, value = loss, value_fn = loss_fn)
    return loss

num_epochs = 20
batch_size = 1024
dataset = jdl.ArrayDataset(X, Y)
train_loader = DataLoaderJAX(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_loader:
        loss = train_step(model, optimizer, X_batch, Y_batch)

    if epoch % 2 == 0:
        print(
            f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}"
        )

print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - jnp.squeeze(model.kernel.value))}")

This is using the full batch, which is appropriate for L-BFGS unless the learning rate is decreased.

@jlperla
Contributor Author

jlperla commented Oct 29, 2024

Just checking in on this issue. I noticed the underlying https://github.com/google/flax/blob/main/flax/nnx/training/optimizer.py code has changed since I posted this workaround, so I am not sure if it is still a good idea or if there is a better way to reimplement update(self, grads, value=None, value_fn=None).

@cgarciae
Collaborator

@jlperla we changed how Optimizer stores the opt_state to make it possible to use the Variable metadata from the parameters in the optimizer states (which are now wrapped inside their own Variables). I believe you should be able to fix your custom version using the new definitions.

I looked at #4351 and at optax.lbfgs and realized that lbfgs implements a different interface, GradientTransformationExtraArgs, rather than the usual GradientTransformation. Based on this I'm wondering if we should either:

  1. Add **kwargs to Optimizer.update that are simply forwarded to tx.update, and let the user handle the complexity.
  2. Implement a new optimizer type to make this family of algorithms easy to use.

Having Optimizer support lbfgs specifically via a default definition for value_fn feels suboptimal. Option 1 seems very attractive and I think we should just do it: there's nothing stopping the user from implementing the equivalent of value_fn_wrapped on their own, and we could add an example to the docs to make the pattern easy to use. Option 2 is a little more involved since it would add more maintenance, but it's arguably a better solution.
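
For reference, here is a minimal sketch of what option 1 could look like, adapted from the workaround earlier in this thread (it assumes the older Optimizer internals used there, i.e. self.wrt, self.tx, self.opt_state and self.step, and is not the actual flax implementation):

import optax
from flax import nnx

class ExtraArgsOptimizer(nnx.Optimizer):  # hypothetical name
    def update(self, grads, **kwargs):
        # Split the model so optax only sees the parameter state pytree.
        gdef, state = nnx.split(self.model, self.wrt)
        # Forward any extra arguments (value, grad, value_fn, ...) untouched
        # to the underlying GradientTransformationExtraArgs.
        updates, new_opt_state = self.tx.update(grads, self.opt_state, state, **kwargs)
        new_params = optax.apply_updates(state, updates)
        self.step.value += 1
        nnx.update(self.model, new_params)
        self.opt_state = new_opt_state

The user would still be responsible for wrapping value_fn so that it accepts the split state, as in value_fn_wrapped above.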

As an aside, if we had infinite resources it would be better to simply have NNX-native optimizers. I recently created a toy version of SGD in 10_fsdp_and_optimizer.py and found it to be very readable. Doing the same for LBFGS would be interesting.
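
To illustrate the NNX-native idea, a bare-bones SGD step could operate directly on the split parameter state (a minimal sketch, not the 10_fsdp_and_optimizer.py code):

import jax
from flax import nnx

def sgd_step(model, grads, lr=0.01):
    # Pull out the current Param state, apply the SGD rule elementwise,
    # and write the result back into the model in place.
    params = nnx.state(model, nnx.Param)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    nnx.update(model, new_params)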

@jlperla
Contributor Author

jlperla commented Oct 31, 2024

Yes, the key is really GradientTransformationExtraArgs. This is not specific to lbfgs() (which, I think, is a convenience wrapper); the real issue is the linesearch, which can be used with other optimizers as well. See https://optax.readthedocs.io/en/latest/api/transformations.html#optax.scale_by_backtracking_linesearch for an example that uses a linesearch with sgd.
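
For example, following the optax docs linked above, a backtracking linesearch can be chained onto plain sgd, and the resulting transformation then needs the same extra update arguments (value, grad, value_fn) as lbfgs. A self-contained sketch using the raw optax loop, with a toy quadratic objective standing in for a real loss:

import jax
import jax.numpy as jnp
import optax

def f(params):  # toy objective for illustration
    return jnp.sum(params ** 2)

params = jnp.array([1.0, 2.0, 3.0])
solver = optax.chain(
    optax.sgd(learning_rate=0.2),
    optax.scale_by_backtracking_linesearch(max_backtracking_steps=15),
)
opt_state = solver.init(params)
for _ in range(5):
    value, grad = jax.value_and_grad(f)(params)
    # The linesearch requires the same extra arguments as lbfgs.
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=f
    )
    params = optax.apply_updates(params, updates)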

I think I see what you mean about the kwargs for the update. I can modify my PR to see what you think.

(For me at least, I am a little skeptical of writing many NNX-native optimizers. I really like the optax gradient transformation infrastructure and its flexibility.)
