Description
System information
Versions:

```
Name: flax
Version: 0.10.2
---
Name: jax
Version: 0.5.3
---
Name: jaxlib
Version: 0.5.3
```
Here is my reproducer:

```python
import os

# Set before JAX initializes its backend so the allocator choice takes effect.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
import jax.numpy as jnp
from flax import nnx

DEPTH = 50
WIDTH = 10_000
NUM_ITERS = 200


class Block(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        super().__init__()
        self.param = nnx.Param(jax.random.normal(rngs(), (WIDTH, WIDTH)))

    def __call__(self, x):
        return self.param @ x


class TestModule(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        super().__init__()
        state_axes = nnx.StateAxes({...: 0})

        @nnx.split_rngs(splits=DEPTH)
        @nnx.vmap(out_axes=state_axes)
        def _create_blocks(rngs: nnx.Rngs):
            return Block(rngs)

        self.blocks = _create_blocks(rngs)

    def __call__(self, x):
        state_axes = nnx.StateAxes({...: 0})

        @nnx.scan(in_axes=(state_axes, nnx.Carry), out_axes=nnx.Carry)
        def _scan(block: Block, x):
            return block(x)

        x = jnp.broadcast_to(x, (WIDTH,))
        return _scan(self.blocks, x).sum()


module = TestModule(nnx.Rngs(0))
graphdef, state = nnx.split(module)


@jax.jit
def run(inner_state, x):
    inner_module = nnx.merge(graphdef, inner_state)
    # without this line, I get a tracer leak. but with this line, it's fine. why?????
    x = inner_module(x)

    def _scan_outer(i, _):
        return inner_module(i), None

    x, _ = jax.lax.scan(_scan_outer, x, length=NUM_ITERS)
    return x


run(state, 0)
```
The result is that, after creating the model, we have one copy of the parameters in memory, occupying about 20 GB. During `run(state, 0)`, the parameters are copied unnecessarily, pushing usage to about 40 GB. I verified this both by polling GPU memory and by running the JAX profiler: the extra copy of the parameters shows up in the Memory Viewer as a large temporary buffer allocation, and in the Graph Viewer as a large, unnecessary copy operation.
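For completeness, a trace like the one inspected above can be captured roughly as follows (a minimal sketch: the directory name and the toy computation are placeholders, not the actual reproducer). The resulting trace can be opened in TensorBoard/Perfetto to reach the Memory Viewer and Graph Viewer:

```python
import tempfile

import jax
import jax.numpy as jnp

# Capture a profiler trace around the computation of interest.
# Here a trivial matmul stands in for run(state, 0).
with tempfile.TemporaryDirectory() as log_dir:
    with jax.profiler.trace(log_dir):
        y = jnp.ones((128, 128)) @ jnp.ones((128, 128))
        y.block_until_ready()  # make sure the work happens inside the trace

print(float(y.sum()))  # 128 * 128 * 128 = 2097152.0
```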
Now, this is definitely the result of some bizarre interaction between NNX and scan:

- Replacing the `nnx.scan` with a split/merge and a plain `jax.lax.scan` fixes the issue.
- Getting rid of `_scan_outer` and just calling `x = inner_module(x)` one or more times also fixes the issue.
- Keeping `_scan_outer` but removing the initial `x = inner_module(x)` call somehow leads to a tracer leak error, which is extremely confusing.
I realize that NNX has to do some fancy bookkeeping to track mutations to graph nodes inside transformations like scan. I also realize that tracking mutations while closing over modules is impossible (this is documented well 👍). However, this example doesn't involve any mutation, and I would reckon that the majority of NNX use cases do not actually care about mutation. It feels like closing over modules should generally behave as expected in these cases, rather than throwing a tracer leak error -- or worse, as in this example, silently causing huge memory allocations.