
Silent memory explosion related to nnx.scan and closures #4871

@kvablack


System information

Versions:

Name: flax
Version: 0.10.2
---
Name: jax
Version: 0.5.3
---
Name: jaxlib
Version: 0.5.3

Here is my reproducer:

```python
from flax import nnx
import jax
import jax.numpy as jnp
import os

os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

DEPTH = 50
WIDTH = 10_000
NUM_ITERS = 200


class Block(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        super().__init__()
        self.param = nnx.Param(jax.random.normal(rngs(), (WIDTH, WIDTH)))

    def __call__(self, x):
        return self.param @ x


class TestModule(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        super().__init__()
        state_axes = nnx.StateAxes({...: 0})

        @nnx.split_rngs(splits=DEPTH)
        @nnx.vmap(out_axes=state_axes)
        def _create_blocks(rngs: nnx.Rngs):
            return Block(rngs)

        self.blocks = _create_blocks(rngs)

    def __call__(self, x):
        state_axes = nnx.StateAxes({...: 0})

        @nnx.scan(in_axes=(state_axes, nnx.Carry), out_axes=nnx.Carry)
        def _scan(block: Block, x):
            return block(x)

        x = jnp.broadcast_to(x, (WIDTH,))
        return _scan(self.blocks, x).sum()


module = TestModule(nnx.Rngs(0))
graphdef, state = nnx.split(module)


@jax.jit
def run(inner_state, x):
    inner_module = nnx.merge(graphdef, inner_state)
    # Without this line I get a tracer leak error; with it, everything runs. Why?
    x = inner_module(x)

    def _scan_outer(i, _):
        return inner_module(i), None

    x, _ = jax.lax.scan(_scan_outer, x, length=NUM_ITERS)
    return x

run(state, 0)
```

After creating the model, there is one copy of the parameters in memory, occupying about 20 GB. During run(state, 0), the parameters are copied unnecessarily, bringing usage to about 40 GB. I verified this both by polling GPU memory and by running the JAX profiler. The extra copy of the parameters shows up in the Memory Viewer as a large temporary buffer allocation and in the Graph Viewer as a large, unnecessary copy operation.
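The 20 GB figure follows directly from the reproducer's constants. Below is a quick check, assuming float32 parameters (the default dtype of `jax.random.normal` when 64-bit mode is off):

```python
DEPTH = 50
WIDTH = 10_000
BYTES_PER_ELEMENT = 4  # float32

# One (WIDTH, WIDTH) matrix per block, DEPTH blocks stacked along axis 0.
param_bytes = DEPTH * WIDTH * WIDTH * BYTES_PER_ELEMENT

print(param_bytes / 1e9)              # 20.0 (decimal GB, one copy)
print(2 * param_bytes / 1e9)          # 40.0 (with the extra copy)
print(round(param_bytes / 2**30, 1))  # 18.6 (GiB)
```

In other words, the observed doubling corresponds to exactly one extra full copy of the stacked parameters.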

This is clearly the result of some bizarre interaction between NNX and scan. Replacing the nnx.scan with a split/merge and a plain jax.lax.scan fixes the issue. Removing _scan_outer and just calling x = inner_module(x) one or more times also fixes the issue. Keeping _scan_outer but removing the initial x = inner_module(x) call leads to a tracer leak error, which is extremely confusing.

I realize that NNX has to do some fancy stuff to keep track of mutations to graph nodes inside transformations like scan. I also realize that tracking mutations while closing over modules is impossible (this is documented well 👍). However, this example doesn't involve any mutation. Furthermore, I would reckon that the majority of NNX use cases don't actually care about mutation. It feels like closing over modules should generally behave as expected in these cases, rather than throwing a tracer leak error or, worse, as in this example, silently causing huge memory allocations.
