Describe the bug
I have an application where I need to output multiple values from a network. I do this by returning a dictionary and selecting entries with the `on` keyword argument. This works fine for the loss functions but not for the metrics.
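For context, here is a minimal sketch of what the `on` keyword is assumed to do: pick one entry out of a dict of predictions and a dict of targets before the loss is computed on plain arrays. `select_on` and `mse_on` are hypothetical helpers for illustration only, not elegy or treex API.

```python
import numpy as np


def select_on(value, on):
    """Hypothetical helper: index into a (possibly nested) dict.

    `on` may be a single key or a tuple of keys forming a path.
    """
    keys = on if isinstance(on, tuple) else (on,)
    for key in keys:
        value = value[key]
    return value


def mse_on(preds, target, on=None):
    # Select the requested entry first, then compute a plain MSE on arrays.
    if on is not None:
        preds = select_on(preds, on)
        target = select_on(target, on)
    return float(np.mean((preds - target) ** 2))


preds = {"y": np.array([[1.0], [2.0]])}
target = {"y": np.array([[0.0], [2.0]])}
print(mse_on(preds, target, on="y"))  # 0.5
```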
Minimal code to reproduce
Small snippet that contains a minimal amount of code.
```python
import elegy as eg
import optax
import numpy as np


def data_generator():
    while True:
        yield (
            np.random.randn(10, 1),
            {"target": {"y": np.random.randn(10, 1)}},
        )


class MyModule(eg.Module):
    @eg.compact
    def __call__(self, x):
        return {"y": eg.nn.Linear(1)(x)}


model = eg.Model(
    MyModule(),
    loss=eg.losses.MeanSquaredError(on="y"),
    metrics=eg.metrics.MeanAbsoluteError(on="y"),  # <-- works fine without this line
    optimizer=optax.adam(1e-3),
)

hist = model.fit(
    data_generator(),
    steps_per_epoch=10,
    epochs=10,
)
```
Stack trace:

```
Traceback (most recent call last):
  File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/scott/Documents/phd/geom/pde/metric.py", line 27, in <module>
    hist = model.fit(
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 417, in fit
    tmp_logs = self.train_on_batch(
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 617, in train_on_batch
    logs, model = train_step_fn(self, inputs, labels)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 412, in _static_train_step
    return model.train_step(inputs, labels)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 306, in train_step
    grads, (logs, model) = grad_fn(params, model, inputs, labels)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 278, in loss_fn
    loss, logs, model = model.test_step(inputs, labels)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 248, in test_step
    batch_loss_and_logs.update(
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/loss_and_logs.py", line 78, in update
    self.metrics.update(**metrics_kwargs)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/metrics.py", line 44, in update
    metric.update(**metric_kwargs)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 83, in update
    values = _mean_absolute_error(preds, target)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 20, in _mean_absolute_error
    target = target.astype(preds.dtype)
AttributeError: 'dict' object has no attribute 'astype'
```
Expected behavior
It should produce the same result as when the dictionaries are removed and `on` is not specified.
Library Info
Please provide os info and elegy version.
- Python version: 3.8.13
- elegy version: 0.8.6
- treex version: 0.6.10
Additional context
From my digging, the cause seems to be that the `Metric.update()` method is called instead of the `__call__` method.
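A toy model of that diagnosis, assuming (not taken from the treex source) that the `on` selection lives only in `__call__` while `update()` works on raw arrays. Calling `update()` directly with dict inputs then fails with the same `AttributeError` as in the traceback above, while going through `__call__` succeeds. `ToyMetric` and `ToyMAE` are hypothetical classes for illustration.

```python
import numpy as np


class ToyMetric:
    """Hypothetical metric: `on` filtering happens only in __call__."""

    def __init__(self, on=None):
        self.on = on
        self.total = 0.0
        self.count = 0

    def __call__(self, preds, target):
        # Select the keyed entry, then delegate to update().
        if self.on is not None:
            preds = preds[self.on]
            target = target[self.on]
        self.update(preds=preds, target=target)

    def update(self, preds, target):
        raise NotImplementedError

    def compute(self):
        return self.total / self.count


class ToyMAE(ToyMetric):
    def update(self, preds, target):
        # Mirrors the failing line in the traceback: a dict has no .astype().
        target = target.astype(preds.dtype)
        self.total += float(np.sum(np.abs(preds - target)))
        self.count += preds.size


preds = {"y": np.array([[1.0], [3.0]])}
target = {"y": np.array([[0.0], [3.0]])}

# Going through __call__ applies `on` first, so this works.
ok = ToyMAE(on="y")
ok(preds, target)
print(ok.compute())  # 0.5

# Calling update() directly skips the `on` selection and raises.
bad = ToyMAE(on="y")
try:
    bad.update(preds=preds, target=target)
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'astype'
```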