shard_map doesn't work with jnp.insert #24762

Open
mrlazy1708 opened this issue Nov 7, 2024 · 0 comments

Description

The code below

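import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

def f(x):
  # Prepend a zero to the local shard, then add a leading axis.
  return jnp.insert(x, 0, 0)[None]

# The walrus binds axis_name = "test" while passing it as the mesh axis name.
mesh = Mesh(jax.devices("gpu"), axis_name:="test")
# Shard the input and the output along the single mesh axis.
f = shard_map(f, mesh, P(axis_name), P(axis_name))
f(jnp.zeros(100))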

raises the following error

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 193, in wrapped
    out_flat = shard_map_p.bind(
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 475, in bind
    outs = top_trace.process_shard_map(  # pytype: disable=attribute-error
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 810, in _shard_map_impl
    outs = fun.call_wrapped(*args)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/linear_util.py", line 193, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "<stdin>", line 2, in f
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 8495, in insert
    values_ind = indices.at[argsort(indices)].add(arange(n_insert, dtype=indices.dtype))
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 2803, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1958, in process_primitive
    out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1248, in _pjit_rewrite
    out_vals = pjit.pjit_p.bind(*args, jaxpr=jaxpr_, **kwargs)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 2803, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 902, in process_primitive
    out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 1253, in _pjit_check
    return _check_rep(mesh, jaxpr.jaxpr, in_rep)
  File "~/.conda/envs/py310/lib/python3.10/site-packages/jax/experimental/shard_map.py", line 631, in _check_rep
    map(write, e.outvars, out_rep)
TypeError: 'NoneType' object is not iterable
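
One reading of the last frame, offered as an assumption rather than a diagnosis: `out_rep` was `None` when `map(write, e.outvars, out_rep)` ran, which would happen if a replication rule for some primitive returned nothing. The pure-Python sketch below reproduces the same message with the builtin `map` as a stand-in; `env`, `write`, `rule_without_return`, and `outvars` are hypothetical names and not JAX internals.

```python
# Stand-in for the final traceback frame: map(write, e.outvars, out_rep).
# Every name here is hypothetical; this is not JAX's implementation.
env = {}

def write(var, rep):
    """Record the replication set computed for one output variable."""
    env[var] = rep

def rule_without_return(*in_reps):
    """A rule that forgets to return a value, so callers receive None."""
    pass

outvars = ["out0"]
out_rep = rule_without_return({"test"})  # evaluates to None

try:
    # map() asks each argument for an iterator up front, so None fails here.
    list(map(write, outvars, out_rep))
    message = ""
except TypeError as err:
    message = str(err)

print(message)
```

If this reading is right, the fix would belong in whichever rule returned `None`, not in `_check_rep` itself.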

The code works as expected when either the shard_map wrapper or the jnp.insert call is removed.
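
Since the failure goes away without `jnp.insert`, one possible workaround is to build the same result with `jnp.concatenate`. The sketch below is an assumption and has not been run against jax 0.4.34, so whether it avoids the error on that version is unverified. `prepend_zero` and `AXIS` are hypothetical names, `jax.devices()` replaces `jax.devices("gpu")` so the example runs on any backend, and the import fallback covers releases that expose `shard_map` at the top level.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    # Newer JAX releases expose shard_map at the top level.
    from jax import shard_map
except ImportError:
    # Older releases, such as the 0.4.34 in this report.
    from jax.experimental.shard_map import shard_map

AXIS = "test"  # same axis name as in the report

def prepend_zero(x):
    # For a 1-D shard this matches jnp.insert(x, 0, 0): a zero in front.
    zero = jnp.zeros((1,), dtype=x.dtype)
    return jnp.concatenate([zero, x])[None]

devices = jax.devices()
mesh = Mesh(devices, (AXIS,))
g = shard_map(prepend_zero, mesh=mesh, in_specs=P(AXIS), out_specs=P(AXIS))

# Four elements per device so the input divides evenly across the mesh.
x = jnp.ones(4 * len(devices))
out = g(x)
print(out.shape)
```

Each shard of length 4 becomes a row of length 5, and the rows are stacked along the mesh axis, so the result has one row per device.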

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.1.3
python: 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]
device info: NVIDIA RTX 6000 Ada Generation-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='**hidden**', release='5.15.0-92-generic', version='#102-Ubuntu SMP Wed Jan 10 09:33:48 UTC 2024', machine='x86_64')


$ nvidia-smi
Thu Nov  7 10:25:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
mrlazy1708 added the bug (Something isn't working) label on Nov 7, 2024
mattjj self-assigned this on Nov 11, 2024