Sparse reshape throws error when n_dense>0 and some target dimension has size 1 #24795

Open
cherrywoods opened this issue Nov 8, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@cherrywoods

Description

Reshaping a sparse BCOO array fails if the target shape contains dimensions of size 1 and the array has at least one dense dimension.

from jax.experimental import sparse

sp_id = sparse.eye(2, n_dense=1)
sp_id.reshape((1, 2, 1, 2))
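
For comparison, the same reshape succeeds on a dense array, and the failure appears only once dense dimensions are involved (the n_dense=0 line is inferred from the behaviour described in the title):

import jax.numpy as jnp
from jax.experimental import sparse

jnp.eye(2).reshape((1, 2, 1, 2))                # dense: works
sparse.eye(2).reshape((1, 2, 1, 2))             # n_dense=0: no error
sparse.eye(2, n_dense=1).reshape((1, 2, 1, 2))  # n_dense=1: TypeError below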

Stack trace:

  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/experimental/sparse/transform.py", line 451, in wrapped
    result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/experimental/sparse/transform.py", line 428, in eval_sparse
    out = sparse_rules_bcoo[prim](spenv, *invals, **eqn.params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/experimental/sparse/transform.py", line 539, in _sparse_rule
    result = sparse_op(*spvalues_to_arrays(spenv, spvalues), **kwds)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/experimental/sparse/bcoo.py", line 1858, in bcoo_reshape
    data = lax.reshape(
           ^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 919, in reshape
    return reshape_p.bind(
           ^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/_src/core.py", line 921, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/david/.miniconda3/envs/formalax/lib/python3.12/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
TypeError: reshape total size must be unchanged, got new_sizes (1, 2) (of total size 2) for shape (2, 2) (of total size 4).
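
The boundary computation inside bcoo_reshape seems to be the culprit. A standalone reconstruction for this example (variable names follow the searchsorted line quoted in the comment below; the surrounding source may differ) shows both size-1 target dimensions being absorbed into the wrong segments:

import numpy as np

# Values for sparse.eye(2, n_dense=1).reshape((1, 2, 1, 2)): the input has
# no batch dimensions, one sparse dimension of size 2, one dense dimension.
batch_size = 1            # product of the (empty) batch shape
sparse_size = 2           # product of the sparse shape (2,)
new_sizes = (1, 2, 1, 2)

cuml_shape = np.cumprod(new_sizes)                      # array([1, 2, 2, 4])
i1 = cuml_shape.searchsorted(batch_size, side='right')  # 1
i2 = cuml_shape.searchsorted(batch_size * sparse_size, side='right')  # 3

# i1 == 1 classifies the leading size-1 dimension as a batch dimension even
# though the input has none, and the subsequent lax.reshape of mat.data ends
# up with a target shape whose total size no longer matches mat.data.shape
# (2, 2), producing the TypeError above.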

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.35
numpy:  2.0.1
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='abc', release='6.8.0-48-generic', version='#48-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 27 14:04:52 UTC 2024', machine='x86_64')


$ nvidia-smi
Fri Nov  8 18:47:48 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce GT 1030         Off | 00000000:1C:00.0  On |                  N/A |
| 35%   41C    P0              N/A /  30W |    589MiB /  2048MiB |     30%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      3509      G   /usr/lib/xorg/Xorg                          163MiB |
|    0   N/A  N/A      3774      G   /usr/bin/gnome-shell                         90MiB |
|    0   N/A  N/A      3956      G   ...irefox/5187/usr/lib/firefox/firefox      120MiB |
|    0   N/A  N/A      3998      G   ...usr/lib/thunderbird/thunderbird-bin        7MiB |
|    0   N/A  N/A      4177      G   ...esktop-client/214/usr/bin/nextcloud        0MiB |
|    0   N/A  N/A      4641      G   /usr/libexec/xdg-desktop-portal-gnome       123MiB |
|    0   N/A  N/A      5609      G   ...yOnDemand --variations-seed-version       14MiB |
|    0   N/A  N/A      9050      G   /usr/bin/nautilus                            31MiB |
|    0   N/A  N/A     25941      G   /usr/bin/gnome-calendar                      16MiB |
|    0   N/A  N/A     75544      G   /usr/bin/gnome-system-monitor                11MiB |
|    0   N/A  N/A     75930      G   ...erProcess --variations-seed-version        1MiB |
+---------------------------------------------------------------------------------------+

Note:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
@cherrywoods
Author

I came up with a hotfix, but I don't think what I implemented is the generally desirable behaviour. My fix is to insert these lines:

while i1 > 0 and new_sizes[i1 - 1] == 1:
    i1 -= 1
while i2 > 0 and new_sizes[i2 - 1] == 1:
    i2 -= 1

after

i2 = cuml_shape.searchsorted(batch_size * sparse_size, side='right')

This moves all dimensions of size one to the next dimension kind (batch -> sparse, sparse -> dense). This works in my case, but I figure there could be other cases where this is precisely the wrong thing to do.
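
As a self-contained sketch (assuming, as above, that i1 and i2 mark the batch/sparse and sparse/dense boundaries of new_sizes), the fix yields boundaries that are consistent with the data layout:

import numpy as np

new_sizes = (1, 2, 1, 2)
cuml_shape = np.cumprod(new_sizes)
i1 = cuml_shape.searchsorted(1, side='right')      # batch_size = 1 -> i1 = 1
i2 = cuml_shape.searchsorted(1 * 2, side='right')  # batch_size * sparse_size = 2 -> i2 = 3

# Hotfix: step back over size-1 dimensions so that they fall into the next
# dimension kind (batch -> sparse, sparse -> dense).
while i1 > 0 and new_sizes[i1 - 1] == 1:
    i1 -= 1
while i2 > 0 and new_sizes[i2 - 1] == 1:
    i2 -= 1

print(new_sizes[:i1], new_sizes[i1:i2], new_sizes[i2:])  # () (1, 2) (1, 2)

With these boundaries, mat.data, of shape (2, 2) = (nse, *dense), can legally be reshaped to (nse, 1, 2), which is presumably why the fix works for this example.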
