Hey everyone,
I want to implement truncated backpropagation to save memory when unrolling long inner loops, e.g. in meta-learning tasks.
My understanding is that this method relies on hiding some of the computation (specifically the initial few steps) from the AD system.
In TensorFlow or PyTorch, one can engage or disengage the "tape" that records operations and thereby easily apply truncated backprop. What would be the suggested way to do this in JAX?
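For context, my mental model is that `jax.lax.stop_gradient` plays the role of "detaching from the tape": it is the identity in the forward pass, but AD treats its output as a constant, so the ops upstream of it are not traversed by the backward pass. A toy sketch of that behaviour:

```python
import jax
import jax.numpy as jnp

def f(x):
    y = x ** 2                      # differentiated normally
    y = jax.lax.stop_gradient(y)    # AD treats y as a constant from here on
    return 3.0 * y

# Prints 0.0: everything upstream of stop_gradient is hidden from the backward pass.
print(jax.grad(f)(2.0))
```

Here is my full attempt at truncating the first 10 of 20 inner steps: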
```python
import jax
import jax.numpy as jnp

# -------------------
# Inner optimization
# -------------------
def inner_loss(w):
    return jnp.sin(w)  # simple toy loss

def init_inner(_lambda):
    return 5.0 * _lambda  # initialization depends on lambda

def inner_update(w, _lambda):
    grad_w = jax.grad(inner_loss)(w)
    return w - 0.1 * _lambda * grad_w  # gradient step with lambda-scaled step size

def run_inner_optimization(w0, _lambda, nsteps=20):
    def step_fn(w, _):
        w_new = inner_update(w, _lambda)
        return w_new, w_new

    w_final, traj = jax.lax.scan(step_fn, w0, None, length=nsteps)
    return w_final, traj  # return both final weight and trajectory

# -------------------
# Outer optimization
# -------------------
def outer_loss(_lambda):
    w0 = init_inner(_lambda)
    # First 10 steps: inputs wrapped in stop_gradient, so AD should not
    # backprop through this part of the unroll.
    w_temp, _ = run_inner_optimization(
        jax.lax.stop_gradient(w0), jax.lax.stop_gradient(_lambda), 10
    )
    # Last 10 steps: differentiated as usual.
    w_final, _ = run_inner_optimization(w_temp, _lambda, 10)
    return (w_final - 1.0) ** 2  # want final weight close to 1

outer_grad = jax.grad(outer_loss)

# -------------------
# Test run
# -------------------
_lambda = 2.0
print("Outer loss:", outer_loss(_lambda))
print("Outer grad wrt λ:", outer_grad(_lambda))
```
Is this the correct way to truncate the first 10 steps? And will it actually mitigate the memory issue?