nnx.vmap example from documentation raises an index error #4355

Open
jhn-nt opened this issue Nov 4, 2024 · 5 comments

Comments

@jhn-nt

jhn-nt commented Nov 4, 2024

I am encountering an index error when running this example from the documentation.

I am running the code in a Docker environment using an NVIDIA image for JAX.

Best,
Giovanni

System information

  • Docker Image ghcr.io/nvidia/jax:base
  • flax==0.9.0
  • jax[cuda12_local]==0.4.34

Problem you have encountered:

Error while going through the nnx tutorial.

What you expected to happen:

The documentation example should run without errors.
Logs, error messages, etc:


IndexError                                Traceback (most recent call last)
Cell In[8], line 23
     19 @partial(nnx.vmap, axis_size=5)
     20 def create_model(rngs: nnx.Rngs):
     21   return MLP(10, 32, 10, rngs=rngs)
---> 23 model = create_model(nnx.Rngs(0))

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/graph.py:1158, in UpdateContextManager.__call__.<locals>.update_context_manager_wrapper(*args, **kwargs)
   1155 @functools.wraps(f)
   1156 def update_context_manager_wrapper(*args, **kwargs):
   1157   with self:
-> 1158     return f(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/iteration.py:339, in vmap.<locals>.vmap_wrapper(*args, **kwargs)
    335 args = resolve_kwargs(f, args, kwargs)
    336 pure_args = extract.to_tree(
    337     args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap'
    338 )
--> 339 pure_args_out, pure_out = vmapped_fn(*pure_args)
    340 _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='vmap')
    341 return out

    [... skipping hidden 3 frame]

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/iteration.py:164, in VmapFn.__call__(self, *pure_args)
    159 pure_args = _update_variable_sharding_metadata(
    160     pure_args, self.transform_metadata, spmd.remove_axis
    161 )
    162 args = extract.from_tree(pure_args, ctxtag='vmap')
--> 164 out = self.f(*args)
    166 args_out = extract.clear_non_graph_nodes(args)
    167 pure_args_out, pure_out = extract.to_tree(
    168     (args_out, out),
    169     prefix=(self.in_axes, self.out_axes),
    170     split_fn=_vmap_split_fn,
    171     ctxtag='vmap',
    172 )

Cell In[8], line 21
     19 @partial(nnx.vmap, axis_size=5)
     20 def create_model(rngs: nnx.Rngs):
---> 21   return MLP(10, 32, 10, rngs=rngs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:79, in ObjectMeta.__call__(cls, *args, **kwargs)
     78 def __call__(cls, *args: Any, **kwargs: Any) -> Any:
---> 79   return _graph_node_meta_call(cls, *args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:88, in _graph_node_meta_call(cls, *args, **kwargs)
     86 node = cls.__new__(cls, *args, **kwargs)
     87 vars(node)['_object__state'] = ObjectState()
---> 88 cls._object_meta_construct(node, *args, **kwargs)
     90 return node

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:82, in ObjectMeta._object_meta_construct(cls, self, *args, **kwargs)
     81 def _object_meta_construct(cls, self, *args, **kwargs):
---> 82   self.__init__(*args, **kwargs)

Cell In[8], line 8
      7 def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
----> 8   self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
      9   self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
     10   self.bn = nnx.BatchNorm(dmid, rngs=rngs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:79, in ObjectMeta.__call__(cls, *args, **kwargs)
     78 def __call__(cls, *args: Any, **kwargs: Any) -> Any:
---> 79   return _graph_node_meta_call(cls, *args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:88, in _graph_node_meta_call(cls, *args, **kwargs)
     86 node = cls.__new__(cls, *args, **kwargs)
     87 vars(node)['_object__state'] = ObjectState()
---> 88 cls._object_meta_construct(node, *args, **kwargs)
     90 return node

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:82, in ObjectMeta._object_meta_construct(cls, self, *args, **kwargs)
     81 def _object_meta_construct(cls, self, *args, **kwargs):
---> 82   self.__init__(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/nn/linear.py:346, in Linear.__init__(self, in_features, out_features, use_bias, dtype, param_dtype, precision, kernel_init, bias_init, dot_general, rngs)
    332 def __init__(
    333     self,
    334     in_features: int,
        (...)
    344     rngs: rnglib.Rngs,
    345 ):
--> 346   kernel_key = rngs.params()
    347   self.kernel = nnx.Param(
    348       kernel_init(kernel_key, (in_features, out_features), param_dtype)
    349   )
    350   if use_bias:

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/rnglib.py:84, in RngStream.__call__(self)
     80 def __call__(self) -> jax.Array:
     81   self.check_valid_context(
     82       lambda: 'Cannot call RngStream from a different trace level'
     83   )
---> 84   key = jax.random.fold_in(self.key.value, self.count.value)
     85   self.count.value += 1
     86   return key

File /usr/local/lib/python3.10/dist-packages/jax/_src/random.py:262, in fold_in(key, data)
    251 def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
    252   """Folds in data to a PRNG key to form a new PRNG key.
    253
    254   Args:
        (...)
    260   statistically safe for producing a stream of new pseudo-random values.
    261   """
--> 262   key, wrapped = _check_prng_key("fold_in", key)
    263   if np.ndim(data):
    264     raise TypeError("fold_in accepts a scalar, but was given an array of"
    265                     f"shape {np.shape(data)} != (). Use jax.vmap for batching.")

File /usr/local/lib/python3.10/dist-packages/jax/_src/random.py:74, in _check_prng_key(name, key, allow_batched)
     72 def _check_prng_key(name: str, key: KeyArrayLike, *,
     73                     allow_batched: bool = False) -> tuple[KeyArray, bool]:
---> 74   if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
     75     wrapped_key = key
     76     wrapped = False

    [... skipping hidden 1 frame]

File /usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/batching.py:346, in BatchTracer.aval(self)
    344   return aval
    345 elif type(self.batch_dim) is int:
--> 346   return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
    347 elif type(self.batch_dim) is RaggedAxis:
    348   new_aval = core.mapped_aval(
    349       aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval)

IndexError: tuple index out of range

Steps to reproduce:

from flax import nnx
import jax
from functools import partial


class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)



@partial(nnx.vmap, axis_size=5)
def create_model(rngs: nnx.Rngs):
  return MLP(10, 32, 10, rngs=rngs)

model = create_model(nnx.Rngs(0))
@jhn-nt
Author

jhn-nt commented Nov 4, 2024

Same behavior after updating to flax==0.10.0 and jax[cuda12_local]==0.4.35.

@cgarciae
Collaborator

cgarciae commented Nov 5, 2024

Hey @jhn-nt, thanks for reporting this! Very curious why our CI is not failing. The easiest fix is to split the keys for the Rngs:

keys = jax.random.split(jax.random.key(0), 5)
model = create_model(nnx.Rngs(keys))

Will fix this quickly.
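For completeness, here is the reproduction with that workaround applied; a minimal sketch, assuming the same MLP definition from the snippet above:

from functools import partial
import jax
from flax import nnx

@partial(nnx.vmap, axis_size=5)
def create_model(rngs: nnx.Rngs):
  return MLP(10, 32, 10, rngs=rngs)

# nnx.Rngs(0) holds a single scalar key of shape (), so vmap's default
# in_axes=0 has no leading axis to map over, which is why batching.py
# fails with "tuple index out of range". Splitting the seed into 5 keys
# gives each of the 5 model instances its own key along axis 0.
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(nnx.Rngs(keys))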

@jhn-nt
Author

jhn-nt commented Nov 5, 2024

Thanks a lot again for the prompt help!

Giovanni

@cgarciae
Collaborator

cgarciae commented Nov 5, 2024

Oh wait, the link you posted is for the old experimental docs on the 0.8.3 version of the site; this is fixed in the new version: https://flax.readthedocs.io/en/latest/nnx_basics.html#scan-over-layers . Did you find this via Google?
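For reference, the updated example forks the rng stream before vmapping, so a scalar seed works again. Roughly, from memory (the exact code is on the linked page; nnx.split_rngs and the decorator form of nnx.vmap are assumed here, available in recent flax releases):

from flax import nnx

# split_rngs forks the single stream in nnx.Rngs(0) into 5 independent
# keys before vmap maps over them, so no manual jax.random.split is needed.
@nnx.split_rngs(splits=5)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_model(rngs: nnx.Rngs):
  return MLP(10, 32, 10, rngs=rngs)

model = create_model(nnx.Rngs(0))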

@jhn-nt
Author

jhn-nt commented Nov 6, 2024

Ah, I see, that explains it then. Apologies for opening the issue; I should have checked in more detail.

But yes, I found it through Google, searching for "flax nnx".

Giovanni
