[Bug] Both API example notebooks have stopped working on Colab #234

Open
nalzok opened this issue May 16, 2022 · 5 comments
Labels
bug Something isn't working

Comments

nalzok commented May 16, 2022

Describe the bug

The Colab runtime freezes during model.fit. It runs for minutes without making any progress, and the progress bar stays at

Epoch 1/100
196/200 [============================>.] - ETA: 0s - accuracy: 0.7653 - crossentropy_loss: 0.8751 - l2_loss: 0.0364 - loss: 0.9114

When I tried to interrupt the cell execution, Colab prompted: "The executing code is not responding to interrupts. Would you like to try restarting the runtime? Runtime state including all local variables will be lost."

I then noticed this comment in the High Level API notebook

# For GPU install proper version of your CUDA, following will work in colab:
! pip install --upgrade jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html

The runtime still freezes after I uncommented it.

Curiously, the Low Level API notebook contains a different command:

# For GPU install proper version of your CUDA, following will work in COLAB:
! pip install --upgrade jax jaxlib==0.1.59+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html
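
This second command upgrades `jax` to the newest release while pinning `jaxlib` to 0.1.59, so the two packages can end up far apart. As a rough illustration, the sketch below compares the pinned version with the `jaxlib` 0.3.7 build that the pip log later in this thread reports as already installed on Colab. The helper name `release_tuple` is hypothetical, and 0.3.7 is only the preinstalled version seen in that log, not a documented minimum.

```python
import re

def release_tuple(version: str) -> tuple:
    """Return the numeric release part of a version string as a tuple of ints.

    Local build tags such as "+cuda101" are dropped before parsing.
    """
    public = version.split("+", 1)[0]
    return tuple(int(part) for part in re.findall(r"\d+", public))

pinned = "0.1.59+cuda101"               # pinned by the Low Level API notebook
preinstalled = "0.3.7+cuda11.cudnn805"  # reported by pip on Colab in this thread

# True means the pinned jaxlib is older than the one Colab already ships.
print(release_tuple(pinned) < release_tuple(preinstalled))
```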

After uncommenting it, I got the following error in model.fit

---------------------------------------------------------------------------

AttributeError                            Traceback (most recent call last)

[<ipython-input-6-1f379fc2ddcc>](https://localhost:8080/#) in <module>()
----> 1 from datasets.load import load_dataset
      2 import numpy as np
      3 
      4 dataset = load_dataset("mnist")
      5 dataset.set_format("np")

18 frames

[/usr/local/lib/python3.7/dist-packages/datasets/__init__.py](https://localhost:8080/#) in <module>()
     35 del version
     36 
---> 37 from .arrow_dataset import Dataset, concatenate_datasets
     38 from .arrow_reader import ReadInstruction
     39 from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder

[/usr/local/lib/python3.7/dist-packages/datasets/arrow_dataset.py](https://localhost:8080/#) in <module>()
     52 import pyarrow as pa
     53 import pyarrow.compute as pc
---> 54 from huggingface_hub import HfApi, HfFolder
     55 from multiprocess import Pool, RLock
     56 from requests import HTTPError

[/usr/local/lib/python3.7/dist-packages/huggingface_hub/__init__.py](https://localhost:8080/#) in <module>()
     68 from .hub_mixin import ModelHubMixin, PyTorchModelHubMixin
     69 from .inference_api import InferenceApi
---> 70 from .keras_mixin import (
     71     KerasModelHubMixin,
     72     from_pretrained_keras,

[/usr/local/lib/python3.7/dist-packages/huggingface_hub/keras_mixin.py](https://localhost:8080/#) in <module>()
     25 
     26 if is_tf_available():
---> 27     import tensorflow as tf
     28 
     29 

[/usr/local/lib/python3.7/dist-packages/tensorflow/__init__.py](https://localhost:8080/#) in <module>()
     49 from ._api.v2 import autograph
     50 from ._api.v2 import bitwise
---> 51 from ._api.v2 import compat
     52 from ._api.v2 import config
     53 from ._api.v2 import data

[/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/__init__.py](https://localhost:8080/#) in <module>()
     35 import sys as _sys
     36 
---> 37 from . import v1
     38 from . import v2
     39 from tensorflow.python.compat.compat import forward_compatibility_horizon

[/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/__init__.py](https://localhost:8080/#) in <module>()
     28 from . import autograph
     29 from . import bitwise
---> 30 from . import compat
     31 from . import config
     32 from . import data

[/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/compat/__init__.py](https://localhost:8080/#) in <module>()
     35 import sys as _sys
     36 
---> 37 from . import v1
     38 from . import v2
     39 from tensorflow.python.compat.compat import forward_compatibility_horizon

[/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/compat/v1/__init__.py](https://localhost:8080/#) in <module>()
     45 from tensorflow._api.v2.compat.v1 import layers
     46 from tensorflow._api.v2.compat.v1 import linalg
---> 47 from tensorflow._api.v2.compat.v1 import lite
     48 from tensorflow._api.v2.compat.v1 import logging
     49 from tensorflow._api.v2.compat.v1 import lookup

[/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/__init__.py](https://localhost:8080/#) in <module>()
      7 
      8 from . import constants
----> 9 from . import experimental
     10 from tensorflow.lite.python.lite import Interpreter
     11 from tensorflow.lite.python.lite import OpHint

[/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/experimental/__init__.py](https://localhost:8080/#) in <module>()
      6 import sys as _sys
      7 
----> 8 from . import authoring
      9 from tensorflow.lite.python.analyzer import ModelAnalyzer as Analyzer
     10 from tensorflow.lite.python.lite import OpResolverType

[/usr/local/lib/python3.7/dist-packages/tensorflow/_api/v2/compat/v1/lite/experimental/authoring/__init__.py](https://localhost:8080/#) in <module>()
      6 import sys as _sys
      7 
----> 8 from tensorflow.lite.python.authoring.authoring import compatible

[/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/authoring/authoring.py](https://localhost:8080/#) in <module>()
     41 
     42 # pylint: disable=g-import-not-at-top
---> 43 from tensorflow.lite.python import convert
     44 from tensorflow.lite.python import lite
     45 from tensorflow.lite.python.metrics import converter_error_data_pb2

[/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/convert.py](https://localhost:8080/#) in <module>()
     27 
     28 from tensorflow.lite.python import lite_constants
---> 29 from tensorflow.lite.python import util
     30 from tensorflow.lite.python import wrap_toco
     31 from tensorflow.lite.python.convert_phase import Component

[/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/util.py](https://localhost:8080/#) in <module>()
     49 # pylint: disable=unused-import
     50 try:
---> 51   from jax import xla_computation as _xla_computation
     52 except ImportError:
     53   _xla_computation = None

[/usr/local/lib/python3.7/dist-packages/jax/__init__.py](https://localhost:8080/#) in <module>()
     33 # We want the exported object to be the class, so we first import the module
     34 # to make sure a later import doesn't overwrite the class.
---> 35 from jax import config as _config_module
     36 del _config_module
     37 

[/usr/local/lib/python3.7/dist-packages/jax/config.py](https://localhost:8080/#) in <module>()
     15 # TODO(phawkins): fix users of this alias and delete this file.
     16 
---> 17 from jax._src.config import config

[/usr/local/lib/python3.7/dist-packages/jax/_src/config.py](https://localhost:8080/#) in <module>()
     25 import warnings
     26 
---> 27 from jax._src import lib
     28 from jax._src.lib import jax_jit
     29 from jax._src.lib import transfer_guard_lib

[/usr/local/lib/python3.7/dist-packages/jax/_src/lib/__init__.py](https://localhost:8080/#) in <module>()
    101 version_str = jaxlib.version.__version__
    102 version = check_jaxlib_version(
--> 103   jax_version=jax.version.__version__,
    104   jaxlib_version=jaxlib.version.__version__,
    105   minimum_jaxlib_version=jax.version._minimum_jaxlib_version)

AttributeError: module 'jax' has no attribute 'version'
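
Because `import jax` itself fails here, one way to see which versions are actually installed is to read the package metadata without importing anything. One plausible reading of the trace is that `jax` and `jaxlib` are out of step, or that the packages were replaced while the runtime was already running; that is an assumption, not something the trace proves. A minimal sketch, with `installed_version` as a hypothetical helper and a fallback to the `importlib_metadata` backport for the Python 3.7 runtime shown above:

```python
try:
    from importlib import metadata          # Python 3.8+
except ImportError:
    import importlib_metadata as metadata   # backport, needed on Python 3.7

def installed_version(package: str):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None

# Report the packages that appear in the tracebacks of this issue.
for name in ("jax", "jaxlib", "elegy", "tensorflow"):
    print(name, installed_version(name))
```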

I have also tried using Elegy in the notebook I have been working on, and got another error

Epoch 1/10

---------------------------------------------------------------------------

UnfilteredStackTrace                      Traceback (most recent call last)

[<ipython-input-11-ac2835db233b>](https://localhost:8080/#) in <module>()
      7     validation_data=(test_ds['image'], test_ds['label']),
----> 8     shuffle=True
      9 )

17 frames

[/usr/local/lib/python3.7/dist-packages/elegy/model/model_base.py](https://localhost:8080/#) in fit(self, inputs, labels, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, drop_remaining)
    418                             inputs=inputs,
--> 419                             labels=labels,
    420                         )

[/usr/local/lib/python3.7/dist-packages/elegy/model/model_core.py](https://localhost:8080/#) in train_on_batch(self, inputs, labels)
    616         train_step_fn = self.train_step_fn[self._distributed_strategy]
--> 617         logs, model = train_step_fn(self, inputs, labels)
    618 

[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in cache_miss(*args, **kwargs)
    475         device=device, backend=backend, name=flat_fun.__name__,
--> 476         donated_invars=donated_invars, inline=inline, keep_unused=keep_unused)
    477     out_pytree_def = out_tree()

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in bind(self, fun, *args, **params)
   1764   def bind(self, fun, *args, **params):
-> 1765     return call_bind(self, fun, *args, **params)
   1766 

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in call_bind(primitive, fun, *args, **params)
   1780   fun_ = lu.annotate(fun_, fun.in_type)
-> 1781   outs = top_trace.process_call(primitive, fun_, tracers, params)
   1782   return map(full_lower, apply_todos(env_trace_todo(), outs))

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in process_call(self, primitive, f, tracers, params)
    677   def process_call(self, primitive, f, tracers, params):
--> 678     return primitive.impl(f, *tracers, **params)
    679   process_map = process_call

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_call_impl(***failed resolving arguments***)
    182   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 183                                keep_unused, *arg_specs)
    184   try:

[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in memoized_fun(fun, *args)
    284     else:
--> 285       ans = call(fun, *args)
    286       cache[key] = (ans, fun.stores)

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)
    230   return lower_xla_callable(fun, device, backend, name, donated_invars, False,
--> 231                             keep_unused, *arg_specs).compile().unsafe_call
    232 

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in compile(self)
    704         self._executable = XlaCompiledComputation.from_xla_computation(
--> 705             self.name, self._hlo, self._explicit_args, **self.compile_args)
    706 

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in from_xla_computation(name, xla_computation, explicit_args, nreps, device, backend, tuple_args, in_avals, out_avals, effects, kept_var_idx, keepalive)
    805                           "in {elapsed_time} sec"):
--> 806       compiled = compile_or_get_cached(backend, xla_computation, options)
    807     buffer_counts = (None if len(out_avals) == 1 and not config.jax_dynamic_shapes

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in compile_or_get_cached(backend, computation, compile_options)
    767     _dump_ir_to_file(module_name, ir_str)
--> 768   return backend_compile(backend, computation, compile_options)
    769 

[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    205     with TraceAnnotation(name, **decorator_kwargs):
--> 206       return func(*args, **kwargs)
    207     return wrapper

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in backend_compile(backend, built_c, options)
    712   # separately in Python profiling results
--> 713   return backend.compile(built_c, compile_options=options)
    714 

UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bw-filter = (f32[5,5,6,16]{1,0,2,3}, u8[0]{0}) custom-call(f32[256,14,14,6]{2,1,3,0} %multiply.9, f32[256,10,10,16]{2,1,3,0} %multiply.6), window={size=5x5}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardFilter", metadata={op_name="jit(_static_train_step)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((0, 0), (0, 0)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(3, 0, 1, 2), rhs_spec=(3, 0, 1, 2), out_spec=(2, 3, 0, 1)) feature_group_count=1 batch_group_count=1 lhs_shape=(256, 14, 14, 6) rhs_shape=(256, 10, 10, 16) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py" source_line=398}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: UNIMPLEMENTED: DNN library is not found.

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------


The above exception was the direct cause of the following exception:

XlaRuntimeError                           Traceback (most recent call last)

[<ipython-input-11-ac2835db233b>](https://localhost:8080/#) in <module>()
      6     batch_size=batch_size,
      7     validation_data=(test_ds['image'], test_ds['label']),
----> 8     shuffle=True
      9 )

[/usr/local/lib/python3.7/dist-packages/elegy/model/model_base.py](https://localhost:8080/#) in fit(self, inputs, labels, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, drop_remaining)
    417                         tmp_logs = self.train_on_batch(
    418                             inputs=inputs,
--> 419                             labels=labels,
    420                         )
    421                         tmp_logs.update({"size": data_handler.batch_size})

[/usr/local/lib/python3.7/dist-packages/elegy/model/model_core.py](https://localhost:8080/#) in train_on_batch(self, inputs, labels)
    615 
    616         train_step_fn = self.train_step_fn[self._distributed_strategy]
--> 617         logs, model = train_step_fn(self, inputs, labels)
    618 
    619         if not isinstance(model, type(self)):

XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bw-filter = (f32[5,5,6,16]{1,0,2,3}, u8[0]{0}) custom-call(f32[256,14,14,6]{2,1,3,0} %multiply.9, f32[256,10,10,16]{2,1,3,0} %multiply.6), window={size=5x5}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBackwardFilter", metadata={op_name="jit(_static_train_step)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((0, 0), (0, 0)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(3, 0, 1, 2), rhs_spec=(3, 0, 1, 2), out_spec=(2, 3, 0, 1)) feature_group_count=1 batch_group_count=1 lhs_shape=(256, 14, 14, 6) rhs_shape=(256, 10, 10, 16) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py" source_line=398}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: UNIMPLEMENTED: DNN library is not found.

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.
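
The error text itself suggests a fallback flag. A minimal sketch of setting it from Python, assuming the code runs before `jax` is first imported in the process so that the variable is in place when the backend initializes; `add_xla_flag` is a hypothetical helper. The flag only relaxes the autotuning check and may not help if the underlying "DNN library is not found" condition persists.

```python
import os

def add_xla_flag(flag: str) -> str:
    """Append a flag to XLA_FLAGS unless it is already present; return the new value."""
    current = os.environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]

# Flag name copied verbatim from the error message above.
add_xla_flag("--xla_gpu_strict_conv_algorithm_picker=false")
```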

Minimal code to reproduce

https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started/high-level-api.ipynb
https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started/low-level-api.ipynb
https://colab.research.google.com/drive/1ZGlTknvwMC8nrrPC_rsSBEGpgcFmVicG?usp=sharing

Expected behavior

Training completes successfully.

Library Info

>>> import elegy
>>> print(elegy.__version__)
0.8.6

Screenshots

Screen Shot 2022-05-16 at 13 36 44

Additional context

I am using a GPU runtime, i.e. Python 3 Google Compute Engine backend (GPU).

nalzok added the bug label May 16, 2022

murphyk commented May 16, 2022

IIUC, it should not be necessary to install jax or jaxlib on Colab, since both are built in.
See, e.g., this lenet_jax notebook.

Author

nalzok commented May 16, 2022

That's true. I was using pip install --upgrade to upgrade them to the latest version, since the default JAX version on Colab (v0.3.8 as of now) doesn't work well with Elegy:

---------------------------------------------------------------------------

AttributeError                            Traceback (most recent call last)

[<ipython-input-2-cbb187e7d76c>](https://localhost:8080/#) in <module>()
     11 import treeo as to
     12 import treex as tx
---> 13 import elegy as eg
     14 
     15 from bokeh.resources import INLINE

6 frames

[/usr/local/lib/python3.7/dist-packages/elegy/__init__.py](https://localhost:8080/#) in <module>()
     16 )
     17 
---> 18 from .model.model import Model
     19 from .model.model_base import ModelBase, load
     20 from .model.model_core import (

[/usr/local/lib/python3.7/dist-packages/elegy/model/model.py](https://localhost:8080/#) in <module>()
      9 
     10 from elegy import types, utils
---> 11 from elegy.model.model_base import ModelBase
     12 from elegy.model.model_core import (
     13     GradStepOutput,

[/usr/local/lib/python3.7/dist-packages/elegy/model/model_base.py](https://localhost:8080/#) in <module>()
     19 from elegy.callbacks.sigint import SigIntMode
     20 from elegy.data import utils as data_utils
---> 21 from elegy.model.model_core import ModelCore, PredStepOutput, TestStepOutput
     22 
     23 __all__ = ["ModelBase", "load"]

[/usr/local/lib/python3.7/dist-packages/elegy/model/model_core.py](https://localhost:8080/#) in <module>()
     14 from elegy import types, utils
     15 
---> 16 from . import utils as model_utils
     17 
     18 try:

[/usr/local/lib/python3.7/dist-packages/elegy/model/utils.py](https://localhost:8080/#) in <module>()
      3 try:
      4     import tensorflow as tf  # type: ignore[import]
----> 5     from jax.experimental import jax2tf  # type: ignore[import]
      6 
      7     def convert_and_save_model(

[/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/__init__.py](https://localhost:8080/#) in <module>()
     13 # limitations under the License.
     14 
---> 15 from jax.experimental.jax2tf.jax2tf import (convert, dtype_of_val,
     16                                             split_to_logical_devices, PolyShape)
     17 from jax.experimental.jax2tf.call_tf import call_tf

[/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/jax2tf.py](https://localhost:8080/#) in <module>()
   2388                     extra_name_stack="checkpoint")
   2389 
-> 2390 tf_impl[lax_control_flow.optimization_barrier_p] = tfxla.optimization_barrier
   2391 
   2392 def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:

AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'optimization_barrier'
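
The traceback shows the `jax2tf` import sitting inside a `try` block in `elegy/model/utils.py`, yet the `AttributeError` still propagates, which suggests the handler there does not cover this exception type (the `except` clause is not visible in the trace, so that is an assumption). A sketch of a broader guard for optional dependencies, using the hypothetical helper `try_import`; this is not Elegy's actual code:

```python
import importlib
import warnings

def try_import(module_name: str):
    """Import an optional module, returning None if the import fails for any reason."""
    try:
        return importlib.import_module(module_name)
    except Exception as exc:  # ImportError, AttributeError from version skew, etc.
        warnings.warn(f"optional dependency {module_name!r} unavailable: {exc!r}")
        return None
```

With such a guard, `try_import("jax.experimental.jax2tf")` would return `None` in the situation above instead of aborting `import elegy`, at the cost of hiding the export-to-TensorFlow feature until the versions line up.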

@cgarciae
Collaborator

@nalzok thanks for reporting this!
These notebooks are tested on CI, but sadly testing on Colab is a manual process. I will try to give it a go, but if you find the fix, it would be amazing if you could contribute it back :)

Author

nalzok commented May 23, 2022

Yeah, I am willing to help but I cannot figure out how to install a package from GitHub. I just created a fork at https://github.com/nalzok/elegy and tried to install it on Colab with

! pip install --upgrade pip
! pip install git+https://github.com/nalzok/elegy

Then I got the message: datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
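
The full log that follows ends with "Successfully installed ... elegy-0.8.6 ...", and folium is not among the packages pip installed, so the message reads as a conflict between two packages that were already on the Colab image rather than a failure of the Elegy install. To separate such conflict lines from the rest of a long pip log, a small parser can help. This is a sketch only: the regular expression is fitted to the single message quoted above, and `parse_conflicts` is a hypothetical helper, not part of pip.

```python
import re

# Matches lines of the form:
#   "<pkg> <version> requires <requirement>, but you have <pkg> <version> which is incompatible."
CONFLICT = re.compile(
    r"^(?P<pkg>\S+) (?P<pkg_version>\S+) requires (?P<requirement>\S+), "
    r"but you have (?P<found>\S+) (?P<found_version>\S+) which is incompatible\.$"
)

def parse_conflicts(log: str) -> list:
    """Return one dict per dependency-conflict line found in pip output."""
    conflicts = []
    for line in log.splitlines():
        match = CONFLICT.match(line.strip())
        if match:
            conflicts.append(match.groupdict())
    return conflicts

sample = (
    "datascience 0.10.6 requires folium==0.2.1, "
    "but you have folium 0.8.3 which is incompatible."
)
print(parse_conflicts(sample))
```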

Full error message (click to expand)
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Requirement already satisfied: pip in /usr/local/lib/python3.7/dist-packages (21.1.3)
Collecting pip
  Downloading pip-22.1.1-py3-none-any.whl (2.1 MB)
     |████████████████████████████████| 2.1 MB 7.9 MB/s 
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.1.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/nalzok/elegy
  Cloning https://github.com/nalzok/elegy to /tmp/pip-req-build-fbzdwabs
  Running command git clone --filter=blob:none --quiet https://github.com/nalzok/elegy /tmp/pip-req-build-fbzdwabs
  Resolved https://github.com/nalzok/elegy to commit 4709ce8dc9dde3925ce717e2358ce49112e36398
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting treex<0.7.0,>=0.6.5
  Downloading treex-0.6.10-py3-none-any.whl (111 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 111.7/111.7 kB 5.8 MB/s eta 0:00:00
Collecting tensorboardx<3.0,>=2.1
  Downloading tensorboardX-2.5-py2.py3-none-any.whl (125 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 125.3/125.3 kB 10.5 MB/s eta 0:00:00
Collecting wandb<0.13.0,>=0.12.10
  Downloading wandb-0.12.16-py2.py3-none-any.whl (1.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 36.9 MB/s eta 0:00:00
Collecting cloudpickle<2.0.0,>=1.5.0
  Downloading cloudpickle-1.6.0-py3-none-any.whl (23 kB)
Requirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.7/dist-packages (from tensorboardx<3.0,>=2.1->elegy==0.8.6) (3.17.3)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from tensorboardx<3.0,>=2.1->elegy==0.8.6) (1.15.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from tensorboardx<3.0,>=2.1->elegy==0.8.6) (1.21.6)
Collecting rich<12.0.0,>=11.2.0
  Downloading rich-11.2.0-py3-none-any.whl (217 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 217.3/217.3 kB 26.2 MB/s eta 0:00:00
Collecting PyYAML<7.0,>=6.0
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 596.3/596.3 kB 47.9 MB/s eta 0:00:00
Collecting flax<0.5.0,>=0.4.0
  Downloading flax-0.4.2-py3-none-any.whl (186 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 186.4/186.4 kB 21.9 MB/s eta 0:00:00
Collecting treeo<0.0.11,>=0.0.10
  Downloading treeo-0.0.10-py3-none-any.whl (17 kB)
Collecting einops<0.5.0,>=0.4.0
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Collecting certifi<2022.0.0,>=2021.10.8
  Downloading certifi-2021.10.8-py2.py3-none-any.whl (149 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 149.2/149.2 kB 20.7 MB/s eta 0:00:00
Collecting optax<0.2.0,>=0.1.1
  Downloading optax-0.1.2-py3-none-any.whl (140 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 140.7/140.7 kB 21.2 MB/s eta 0:00:00
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... done
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.5.12-py2.py3-none-any.whl (145 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 145.3/145.3 kB 17.7 MB/s eta 0:00:00
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (57.4.0)
Collecting setproctitle
  Downloading setproctitle-1.2.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29 kB)
Requirement already satisfied: Click!=8.0.0,>=7.0 in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (7.1.2)
Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (2.23.0)
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (2.8.2)
Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (5.4.8)
Requirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.7/dist-packages (from wandb<0.13.0,>=0.12.10->elegy==0.8.6) (2.3)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 181.2/181.2 kB 22.3 MB/s eta 0:00:00
Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (3.2.2)
Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (1.0.3)
Requirement already satisfied: jax>=0.3 in /usr/local/lib/python3.7/dist-packages (from flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (0.3.8)
Requirement already satisfied: typing-extensions>=4.1.1 in /usr/local/lib/python3.7/dist-packages (from flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (4.2.0)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.1/63.1 kB 7.2 MB/s eta 0:00:00
Collecting chex>=0.0.4
  Downloading chex-0.1.3-py3-none-any.whl (72 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 72.2/72.2 kB 10.8 MB/s eta 0:00:00
Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax<0.2.0,>=0.1.1->treex<0.7.0,>=0.6.5->elegy==0.8.6) (1.0.0)
Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax<0.2.0,>=0.1.1->treex<0.7.0,>=0.6.5->elegy==0.8.6) (0.3.7+cuda11.cudnn805)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb<0.13.0,>=0.12.10->elegy==0.8.6) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb<0.13.0,>=0.12.10->elegy==0.8.6) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb<0.13.0,>=0.12.10->elegy==0.8.6) (1.24.3)
Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.7/dist-packages (from rich<12.0.0,>=11.2.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (2.6.1)
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 51.1/51.1 kB 7.5 MB/s eta 0:00:00
Collecting colorama<0.5.0,>=0.4.0
  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax<0.2.0,>=0.1.1->treex<0.7.0,>=0.6.5->elegy==0.8.6) (0.1.7)
Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax<0.2.0,>=0.1.1->treex<0.7.0,>=0.6.5->elegy==0.8.6) (0.11.2)
Collecting smmap<6,>=3.0.1
  Downloading smmap-5.0.0-py3-none-any.whl (24 kB)
Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (1.4.1)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.3->flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (3.3.0)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax<0.2.0,>=0.1.1->treex<0.7.0,>=0.6.5->elegy==0.8.6) (2.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (1.4.2)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (0.11.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax<0.5.0,>=0.4.0->treex<0.7.0,>=0.6.5->elegy==0.8.6) (3.0.9)
Building wheels for collected packages: elegy, pathtools
  Building wheel for elegy (pyproject.toml) ... done
  Created wheel for elegy: filename=elegy-0.8.6-py3-none-any.whl size=72228 sha256=cbaac711df1e4b92557b49daf5b8f819f6505a86f8e130784b7d66ea1636e41d
  Stored in directory: /tmp/pip-ephem-wheel-cache-rq_lvte2/wheels/71/e7/f6/574c5a5046b672581176a5d22b710ded1fd1db6715b187d363
  Building wheel for pathtools (setup.py) ... done
  Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8806 sha256=b6b282cad3d6d596fec3fdb9350ef8774617a58d5f05a795781d3df8a9b1d850
  Stored in directory: /root/.cache/pip/wheels/3e/31/09/fa59cef12cdcfecc627b3d24273699f390e71828921b2cbba2
Successfully built elegy pathtools
Installing collected packages: pathtools, einops, commonmark, certifi, treeo, smmap, shortuuid, setproctitle, sentry-sdk, PyYAML, docker-pycreds, colorama, cloudpickle, tensorboardx, rich, gitdb, GitPython, chex, wandb, optax, flax, treex, elegy
  Attempting uninstall: certifi
    Found existing installation: certifi 2022.5.18.1
    Uninstalling certifi-2022.5.18.1:
      Successfully uninstalled certifi-2022.5.18.1
  Attempting uninstall: PyYAML
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
  Attempting uninstall: cloudpickle
    Found existing installation: cloudpickle 1.3.0
    Uninstalling cloudpickle-1.3.0:
      Successfully uninstalled cloudpickle-1.3.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
Successfully installed GitPython-3.1.27 PyYAML-6.0 certifi-2021.10.8 chex-0.1.3 cloudpickle-1.6.0 colorama-0.4.4 commonmark-0.9.1 docker-pycreds-0.4.0 einops-0.4.1 elegy-0.8.6 flax-0.4.2 gitdb-4.0.9 optax-0.1.2 pathtools-0.1.2 rich-11.2.0 sentry-sdk-1.5.12 setproctitle-1.2.3 shortuuid-1.0.9 smmap-5.0.0 tensorboardx-2.5 treeo-0.0.10 treex-0.6.10 wandb-0.12.16
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.2.2-py3-none-any.whl (346 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 346.8/346.8 kB 12.6 MB/s eta 0:00:00
Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (3.2.2)
Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from datasets) (21.3)
Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2.23.0)
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 49.3 MB/s eta 0:00:00
Collecting huggingface-hub<1.0.0,>=0.1.0
  Downloading huggingface_hub-0.6.0-py3-none-any.whl (84 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.4/84.4 kB 11.7 MB/s eta 0:00:00
Collecting dill<0.3.5
  Downloading dill-0.3.4-py2.py3-none-any.whl (86 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 86.9/86.9 kB 8.7 MB/s eta 0:00:00
Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0.1)
Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.12.2)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.3.5)
Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (4.64.0)
Collecting xxhash
  Downloading xxhash-3.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 212.2/212.2 kB 25.7 MB/s eta 0:00:00
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting fsspec[http]>=2021.05.0
  Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 140.6/140.6 kB 19.7 MB/s eta 0:00:00
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from datasets) (1.21.6)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from datasets) (4.11.3)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib) (2.8.2)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib) (1.4.2)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib) (3.0.9)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib) (0.11.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.7.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (4.2.0)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (6.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib) (1.15.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2021.10.8)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2.10)
Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1
  Downloading urllib3-1.25.11-py2.py3-none-any.whl (127 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 128.0/128.0 kB 17.2 MB/s eta 0:00:00
Collecting async-timeout<5.0,>=4.0.0a3
  Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)
Collecting frozenlist>=1.1.1
  Downloading frozenlist-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (144 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 144.8/144.8 kB 1.9 MB/s eta 0:00:00
Collecting aiosignal>=1.1.2
  Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)
Collecting multidict<7.0,>=4.5
  Downloading multidict-6.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (94 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 94.8/94.8 kB 13.9 MB/s eta 0:00:00
Collecting asynctest==0.13.0
  Downloading asynctest-0.13.0-py3-none-any.whl (26 kB)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.0.12)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (21.4.0)
Collecting yarl<2.0,>=1.0
  Downloading yarl-1.7.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (271 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 271.8/271.8 kB 31.2 MB/s eta 0:00:00
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->datasets) (3.8.0)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2022.1)
Installing collected packages: xxhash, urllib3, multidict, fsspec, frozenlist, dill, asynctest, async-timeout, yarl, aiosignal, responses, huggingface-hub, aiohttp, datasets
  Attempting uninstall: urllib3
    Found existing installation: urllib3 1.24.3
    Uninstalling urllib3-1.24.3:
      Successfully uninstalled urllib3-1.24.3
  Attempting uninstall: dill
    Found existing installation: dill 0.3.5.1
    Uninstalling dill-0.3.5.1:
      Successfully uninstalled dill-0.3.5.1
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
Successfully installed aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 datasets-2.2.2 dill-0.3.4 frozenlist-1.3.0 fsspec-2022.5.0 huggingface-hub-0.6.0 multidict-6.0.2 responses-0.18.0 urllib3-1.25.11 xxhash-3.0.0 yarl-1.7.2
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

@jiyuuchc

I looked into this a bit, since I was doing some testing on Colab.

It seems that calling reset_metrics() causes every later call to any of the JIT-compiled model functions to hang. This can be demonstrated by overriding reset_metrics() with a no-op:

def do_nothing():
  pass

model.reset_metrics = do_nothing

With this override in place, training runs to completion.
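If you only want to disable reset_metrics() for the duration of a fit call, the same trick can be wrapped so that it is undone afterwards. The following is only a sketch of the pattern: `StandInModel` and `without_reset_metrics` are hypothetical names made up for illustration, the stand-in class is not Elegy's model, and I have not checked that removing the override works the same way on a real Elegy model.

```python
from contextlib import contextmanager


class StandInModel:
    """Hypothetical stand-in for the real model; it only counts calls."""

    def __init__(self):
        self.reset_calls = 0

    def reset_metrics(self):
        self.reset_calls += 1


@contextmanager
def without_reset_metrics(model):
    """Temporarily shadow model.reset_metrics with a no-op on this instance."""

    def do_nothing(*args, **kwargs):
        # Accept any arguments so the caller's signature does not matter.
        pass

    # Assigning on the instance shadows the method defined on the class.
    model.reset_metrics = do_nothing
    try:
        yield model
    finally:
        # Deleting the instance attribute makes the class method visible again.
        del model.reset_metrics


model = StandInModel()

with without_reset_metrics(model):
    model.reset_metrics()  # no-op while the override is active
calls_inside = model.reset_calls

model.reset_metrics()  # original method is back
calls_after = model.reset_calls
```

With a real model, the idea would be to put the model.fit call inside the `with` block. Skipping the reset presumably means metric state is carried over between epochs, so this is a diagnostic workaround rather than a fix.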
