AttributeError: module 'jax.ops' has no attribute 'index_add' #1773

@cmosguy

Description

I am trying to run these basic imports:

```python
import numpy as np              # regular ol' numpy
from trax import layers as tl   # core building block
from trax import shapes         # data signatures: dimensionality and type
from trax import fastmath       # uses jax, offers numpy on steroids
```

The very first import fails with the error shown below. What am I doing wrong? Should I be pinning a different version of one of these packages?

Environment information

OS: CentOS
$ lsb_release
LSB Version: :core-4.1-amd64:core-4.1-ia32:core-4.1-noarch:cxx-4.1-amd64:cxx-4.1-ia32:cxx-4.1-noarch:desktop-4.1-amd64:desktop-4.1-ia32:desktop-4.1-noarch:languages-4.1-amd64:languages-4.1-noarch:printing-4.1-amd64:printing-4.1-noarch

$ pip freeze | grep trax
trax==1.3.9

$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.2
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.30.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0

$ pip freeze | grep jax
jax==0.4.4
jaxlib==0.4.4

$ python -V
Python 3.9.16
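The same version information can also be gathered from inside the Python process that raises the error, which rules out `pip` and `python` pointing at different installs. A minimal sketch using only the standard library's `importlib.metadata` (Python 3.8+); the helper name `installed_versions` is illustrative and is not part of Trax or JAX:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Distribution is not installed in this interpreter's environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["trax", "jax", "jaxlib", "tensorflow"]).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```

Running it with the same interpreter that executes the failing imports (e.g. the `coursera-nlp` conda env from the traceback) should report the same versions as the `pip freeze` output above.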


### For bugs: reproduction and error logs

# Error logs:

```
...

      1 # coding=utf-8
      2 # Copyright 2021 The Trax Authors.
      3 #
   (...)
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     16 """Trax top level import."""
---> 18 from trax import data
     19 from trax import fastmath
     20 from trax import layers

File ./ds_work/miniconda3/envs/coursera-nlp/lib/python3.9/site-packages/trax/data/__init__.py:36, in <module>
     16 """Functions and classes for obtaining and preprocesing data.
     17 
     18 The ``trax.data`` module presents a flattened (no subpackages) public API.
   (...)
...
    217     'vjp': jax.vjp,
    218     'vmap': jax.vmap,
    219 }

AttributeError: module 'jax.ops' has no attribute 'index_add'
```
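The error means the installed JAX no longer exposes `jax.ops.index_add`, which this Trax release still references at import time; newer JAX releases provide the same scatter-add through the indexed-update syntax `x.at[idx].add(y)` instead. A quick way to confirm which situation a given environment is in is to probe for the attribute directly. A minimal sketch; `has_attribute` is a hypothetical helper, and it returns `None` when the module itself cannot be imported:

```python
import importlib
from typing import Optional


def has_attribute(module_name: str, attr: str) -> Optional[bool]:
    """Return True/False for whether `module_name` exposes `attr`,
    or None if the module cannot be imported at all."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Covers ModuleNotFoundError, e.g. when jax is not installed.
        return None
    return hasattr(module, attr)


if __name__ == "__main__":
    # With jax==0.4.4 installed, as in this report, the traceback above
    # indicates this should print False.
    print("jax.ops.index_add available:", has_attribute("jax.ops", "index_add"))
```

If the probe reports `False`, the mismatch is between the Trax release and the installed JAX/jaxlib pair rather than anything in the import code itself.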
