CSVLogger iteration over a 0-d array #248

Open
ScottAlexanderCameron opened this issue Sep 28, 2022 · 0 comments
Labels
bug Something isn't working

ScottAlexanderCameron commented Sep 28, 2022

Describe the bug
When using the CSVLogger callback, elegy crashes at the end of the first epoch.

Minimal code to reproduce

import elegy as eg
import optax
import numpy as np

x = np.random.randn(64, 1)
y = np.random.randn(64, 1)

model = eg.Model(
    eg.nn.Linear(1),
    loss=eg.losses.MeanSquaredError(),
    optimizer=optax.adam(1e-3),
)

hist = model.fit(
    x,
    y,
    epochs=10,
    callbacks=[
        eg.callbacks.CSVLogger("train.csv"),  # <-- commenting out this line avoids the error
    ]
)

Stack trace:

Epoch 1/10
2/2 [==============================] - ETA: 0s - loss: 1.3408 - mean_squared_error_loss: 1.3408
Traceback (most recent call last):
  File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/scott/Documents/phd/geom/pde/csv.py", line 14, in <module>
    hist = model.fit(
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 465, in fit
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/callback_list.py", line 221, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in on_epoch_end
    row_dict.update((key, handle_value(logs[key])) for key in self.keys)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in <genexpr>
    row_dict.update((key, handle_value(logs[key])) for key in self.keys)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 68, in handle_value
    return '"[%s]"' % (", ".join(map(str, k)))
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py", line 245, in __iter__
    raise TypeError("iteration over a 0-d array")  # same as numpy error
TypeError: iteration over a 0-d array

Expected behavior
CSVLogger should treat a 0-d array as a scalar instead of trying to iterate over it.
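For reference, plain NumPy shows the same behavior, so the problem can be seen without JAX: a 0-d array passes an Iterable check because the class defines __iter__, yet actually iterating over it raises the same TypeError. "Treating it as a scalar" can mean unwrapping it with .item(). A minimal sketch using only NumPy (the value 1.5 is arbitrary):

```python
from collections.abc import Iterable

import numpy as np

value = np.array(1.5, dtype=np.float32)  # a 0-d array, like a logged loss

# The Iterable check succeeds because np.ndarray defines __iter__...
print(isinstance(value, Iterable))

# ...but iterating raises, just like the DeviceArray in the trace.
try:
    list(value)
    raised = False
except TypeError as err:
    raised = True
    print(err)

# Treating it as a scalar: unwrap it into a plain Python float.
scalar = value.item()
print(type(scalar).__name__, scalar)
```

This is why handle_value needs an explicit zero-dimensional check before its Iterable branch.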

Library Info
python version: 3.8.13
elegy version: 0.8.6
treex version: 0.6.10

Additional context
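More detailed error information shows that the error occurs because the logged value is a JAX DeviceArray, not a NumPy ndarray. The zero-dimensional check in handle_value

is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0

therefore evaluates to False, and the value falls through to the tp.Iterable branch, which tries to iterate over it: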
/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py:68 in handle_value

   65             if isinstance(k, six.string_types):
   66                 return k
   67             elif isinstance(k, tp.Iterable) and not is_zero_dim_ndarray:
❱  68                 return '"[%s]"' % (", ".join(map(str, k)))
   69             else:
   70                 return k
   71

locals:
  is_zero_dim_ndarray = False
                    k = DeviceArray(4.8264385e-05, dtype=float32)

/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py:245 in __iter__

  242
  243   def __iter__(self):
  244     if self.ndim == 0:
❱ 245       raise TypeError("iteration over a 0-d array")  # same as numpy error
  246     else:
  247       return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())
  248

locals:
  self = DeviceArray(4.8264385e-05, dtype=float32)

TypeError: iteration over a 0-d array
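The failing check can be reproduced without JAX. The sketch below uses a hypothetical FakeDeviceArray class as a stand-in for JAX's DeviceArray: it exposes ndim and raises on 0-d iteration, but it is not an np.ndarray, so the isinstance-based check is False, matching the locals above. One possible fix, assuming any object reporting ndim == 0 should be logged as a scalar, is to check the ndim attribute instead of the concrete type. Both handle_value_original and handle_value_fixed are hypothetical names for illustration, not elegy's code, and six.string_types is replaced by str (Python 3 only).

```python
import typing as tp

import numpy as np


class FakeDeviceArray:
    """Hypothetical stand-in for jax's DeviceArray: array-like, but not an np.ndarray."""

    def __init__(self, value):
        self._value = np.asarray(value)
        self.ndim = self._value.ndim

    def __iter__(self):
        if self.ndim == 0:
            raise TypeError("iteration over a 0-d array")  # mirrors the numpy/jax error
        return iter(self._value)

    def item(self):
        return self._value.item()


def handle_value_original(k):
    # The check from the traceback: it only recognizes real np.ndarray instances.
    is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
    if isinstance(k, str):
        return k
    elif isinstance(k, tp.Iterable) and not is_zero_dim_ndarray:
        return '"[%s]"' % (", ".join(map(str, k)))
    else:
        return k


def handle_value_fixed(k):
    # Duck-typed check: anything reporting ndim == 0 (np.ndarray, NumPy scalars,
    # array-likes such as the stand-in above) is unwrapped into a Python scalar.
    if isinstance(k, str):
        return k
    if getattr(k, "ndim", None) == 0:
        return k.item()
    if isinstance(k, tp.Iterable):
        return '"[%s]"' % (", ".join(map(str, k)))
    return k


loss = FakeDeviceArray(4.8264385e-05)
print(handle_value_fixed(loss))       # a plain float
print(handle_value_fixed([1, 2]))     # lists are still formatted as before
try:
    handle_value_original(loss)
except TypeError as err:
    print("original check fails:", err)
```

Checking ndim rather than the concrete type keeps the logger independent of which array library produced the metric; this assumes the value also provides .item(), as NumPy arrays do.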
ScottAlexanderCameron added the bug label on Sep 28, 2022
Projects
None yet
Development

No branches or pull requests

1 participant