
Could not normally run trax using GPU in local computer #1778

@LiuZhenshun

Description


Hi, I would like to install trax locally. First, I found that the jax I had installed did not support the GPU, so I followed the instructions on the jax GitHub page to install the CUDA build of jax. Next, I verified that jax can detect the GPU on my local computer, but I still cannot run the sample code, such as the Transformer and fast math examples.

Environment information

OS: Pop!_OS (based on Ubuntu 22.04)

$ pip freeze | grep trax
# trax==1.4.1

$ pip freeze | grep tensor
# tensorboard==2.12.3
# tensorboard-data-server==0.7.1
# tensorflow==2.12.0
# tensorflow-datasets==4.9.2
# tensorflow-estimator==2.12.0
# tensorflow-hub==0.13.0
# tensorflow-io-gcs-filesystem==0.32.0
# tensorflow-metadata==1.13.1
# tensorflow-text==2.12.1

$ pip freeze | grep jax
# jax==0.4.12
# jaxlib==0.4.12+cuda11.cudnn86

$ python -V
# Python 3.11.3

For bugs: reproduction and error logs

# Steps to reproduce:
1) Install trax

- pip install trax
- pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

2) Use jax to detect the GPU
- code:
    import jax
    print(jax.devices()) 
- output:
    [gpu(id=0)]
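
One variation of the detection step that may be worth trying is to turn off JAX's default GPU memory preallocation before `jax` is imported. The sketch below is only an assumption about what might help on a small GPU, not a confirmed fix for this issue. `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` are environment variables described in the JAX GPU memory allocation documentation; the helper name `configure_jax_gpu_memory` is hypothetical and introduced here for illustration only.

```python
import os


def configure_jax_gpu_memory(preallocate=False, mem_fraction=None):
    """Set JAX GPU memory environment variables.

    Hypothetical helper: it must be called before `import jax`,
    because the variables are read when the GPU backend starts.
    """
    # "false" asks JAX to allocate GPU memory on demand instead of
    # reserving a large share of it up front.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Fraction of total GPU memory JAX may preallocate (only relevant
        # when preallocation is enabled).
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


configure_jax_gpu_memory(preallocate=False)
print(os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"])

# After this point the detection code from step 2 can be run unchanged:
#     import jax
#     print(jax.devices())
```

Because trax also imports TensorFlow, which may reserve GPU memory of its own, this setting alone might not be sufficient; that interaction is an assumption and was not tested here.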
# Error logs:
1) Run the sample code for the pre-trained Transformer from your README tutorial
- code:
      import os
      import numpy as np
      
      import trax
      
      # Create a Transformer model.
      # Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
      model = trax.models.Transformer(
          input_vocab_size=33300,
          d_model=512, d_ff=2048,
          n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
          max_len=64, mode='predict')
      
      # Initialize using pre-trained weights.
      model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                           weights_only=True)
                          #  input_signature=input_signature)
      
      # Tokenize a sentence.
      sentence = 'It is nice to learn new things today!'
      tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                          vocab_dir='gs://trax-ml/vocabs/',
                                          vocab_file='ende_32k.subword'))[0]
      
      # Decode from the Transformer.
      tokenized = tokenized[None, :]  # Add batch dimension.
      tokenized_translation = trax.supervised.decoding.autoregressive_sample(
          model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.
      
      # De-tokenize.
      tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
      translation = trax.data.detokenize(tokenized_translation,
                                         vocab_dir='gs://trax-ml/vocabs/',
                                         vocab_file='ende_32k.subword')
      print(translation)
- Error Output:
      2023-06-22 15:58:35.266959: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".
      2023-06-22 15:58:56.630331: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
      INTERNAL: Failed to get stream's capture status: out of memory
      2023-06-22 15:58:56.630403: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.
      Traceback (most recent call last):
        File "/home/littleliu/Documents/project/trax_learning/tryTrax.py", line 22, in <module>
          model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/base.py", line 349, in init_from_file
          self.init(input_signature)
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/base.py", line 310, in init
          raise LayerError(name, 'init', self._caller,
      trax.layers.base.LayerError: Exception passing through layer Serial (in init):
        layer created in file [...]/trax/models/transformer.py, line 371
        layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:float32})
      
        File [...]/trax/layers/combinators.py, line 108, in init_weights_and_state
          outputs, _ = sublayer._forward_abstract(inputs)
      
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File [...]/trax/layers/base.py, line 641, in _forward_abstract
      
        layer created in file [...]/trax/models/transformer.py, line 372
        layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64})
      
      jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.

2) Run the Fast Math sample code:
- code:
      import trax
      from trax.fastmath import numpy as fastnp
      trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.
      
      matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
      print(f'matrix =\n{matrix}')
      vector = fastnp.ones(3)
      print(f'vector = {vector}')
      product = fastnp.dot(vector, matrix)
      print(f'product = {product}')
      tanh = fastnp.tanh(product)
      print(f'tanh(product) = {tanh}')
- Error Output:
      matrix =
      [[1 2 3]
       [4 5 6]
       [7 8 9]]
      2023-06-22 16:03:23.041313: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
      2023-06-22 16:03:23.041386: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 36175872 bytes free, 4093902848 bytes total.
      2023-06-22 16:03:23.041476: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:453] Possibly insufficient driver version: 525.85.5
      Traceback (most recent call last):
        File "/home/littleliu/Documents/project/trax_learning/fastnumpy.py", line 7, in <module>
          vector = fastnp.ones(3)
                   ^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2161, in ones
          return lax.full(shape, 1, _jnp_dtype(dtype))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1205, in full
          return broadcast(fill_value, shape)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 768, in broadcast
          return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
          return broadcast_in_dim_p.bind(
                 ^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
          return self.bind_with_trace(find_top_trace(args), args, params)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
          out = trace.process_primitive(self, map(trace.full_raise, args), params)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
          return primitive.impl(*tracers, **params)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
          compiled_fun = xla_primitive_callable(
                         ^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
          return cached(config._trace_context(), *args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
          return f(*args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
          compiled = _xla_callable_uncached(
                     ^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
          return computation.compile().unsafe_call
                 ^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2329, in compile
          executable = UnloadedMeshExecutable.from_hlo(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2651, in from_hlo
          xla_executable, compile_options = _cached_compilation(
                                            ^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2561, in _cached_compilation
          xla_executable = dispatch.compile_or_get_cached(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
          return backend_compile(backend, computation, compile_options,
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
          return func(*args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
          return backend.compile(built_c, compile_options=options)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
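
The "Memory usage" line in the log above can be read more easily once it is converted to MiB. The following is a minimal sketch that parses that one line, copied verbatim from the error output, and reports how much GPU memory was free when cuDNN failed to initialize. The function name `parse_memory_usage` is hypothetical, and the regular expression assumes the exact wording shown in this log.

```python
import re

# Log line copied from the Fast Math error output above.
LOG_LINE = (
    "2023-06-22 16:03:23.041386: E external/xla/xla/stream_executor/cuda/"
    "cuda_dnn.cc:443] Memory usage: 36175872 bytes free, 4093902848 bytes total."
)


def parse_memory_usage(line):
    """Return (free_bytes, total_bytes) from a 'Memory usage' log line, or None."""
    match = re.search(r"Memory usage: (\d+) bytes free, (\d+) bytes total", line)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


free_bytes, total_bytes = parse_memory_usage(LOG_LINE)
free_mib = free_bytes / 2**20      # bytes -> MiB
total_mib = total_bytes / 2**20
free_percent = 100 * free_bytes / total_bytes

print(f"free: {free_mib:.2f} MiB of {total_mib:.2f} MiB ({free_percent:.2f}%)")
```

With the values from this log, under 1% of the roughly 3.8 GiB of GPU memory was free at that moment, which seems consistent with the "out of memory" message in the first error log. Whether that memory was held by this process or by another one cannot be determined from the log alone.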
