You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 193, in wrapped
out_flat = shard_map_p.bind(
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 475, in bind
outs = top_trace.process_shard_map( # pytype: disable=attribute-error
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 810, in _shard_map_impl
outs = fun.call_wrapped(*args)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "<stdin>", line 2, in f
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 8495, in insert
values_ind = indices.at[argsort(indices)].add(arange(n_insert, dtype=indices.dtype))
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **p.params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 2803, in bind
return self.bind_with_trace(top_trace, args, params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 442, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1958, in process_primitive
out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1248, in _pjit_rewrite
out_vals = pjit.pjit_p.bind(*args, jaxpr=jaxpr_, **kwargs)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 2803, in bind
return self.bind_with_trace(top_trace, args, params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 442, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 902, in process_primitive
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1253, in _pjit_check
return _check_rep(mesh, jaxpr.jaxpr, in_rep)
File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 631, in _check_rep
map(write, e.outvars, out_rep)
TypeError: 'NoneType' object is not iterable
Removing either the `shmap` (`shard_map`) wrapper or the `jnp.insert` call makes the code work as expected.
System info (python version, jaxlib version, accelerator, etc.)
>>> import jax; jax.print_environment_info()
jax: 0.4.34
jaxlib: 0.4.34
numpy: 2.1.3
python: 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0]
device info: NVIDIA RTX 6000 Ada Generation-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='**hidden**', release='5.15.0-92-generic', version='#102-Ubuntu SMP Wed Jan 10 09:33:48 UTC 2024', machine='x86_64')
$ nvidia-smi
Thu Nov 7 10:25:17 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
The text was updated successfully, but these errors were encountered:
Description
The code below raises the following error:
Removing either the `shmap` (`shard_map`) wrapper or the `jnp.insert` call makes the code work as expected.
System info (python version, jaxlib version, accelerator, etc.)
The text was updated successfully, but these errors were encountered: