Metrics ignore "on" keyword arg #247

Open
ScottAlexanderCameron opened this issue Sep 28, 2022 · 0 comments
Labels
bug Something isn't working

@ScottAlexanderCameron

Describe the bug
I have an application where a network needs to output multiple values. I return them in a dictionary and select each one with the on keyword argument. This works fine for the loss functions but fails for metrics.
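To make the intended semantics concrete, here is a minimal, hypothetical sketch in plain NumPy (not elegy's or treex's actual code) of what on="y" is expected to do: pick the "y" entry out of both the prediction dict and the target dict before the metric value is computed. The helper names select_on and mean_absolute_error are made up for illustration.

```python
import numpy as np


def select_on(value, on):
    """Hypothetical helper: index a (possibly nested) dict by one key or a tuple of keys."""
    keys = on if isinstance(on, tuple) else (on,)
    for key in keys:
        value = value[key]
    return value


def mean_absolute_error(preds, target):
    # Plain MAE: cast target to the prediction dtype, then mean of |preds - target|.
    target = target.astype(preds.dtype)
    return float(np.mean(np.abs(preds - target)))


rng = np.random.default_rng(0)
y_pred = rng.standard_normal((10, 1))
y_true = rng.standard_normal((10, 1))

preds = {"y": y_pred}   # shaped like what MyModule returns
target = {"y": y_true}  # shaped like the "target" entry of the labels dict

# With on="y", the metric should only ever see plain arrays, never dicts.
mae_on = mean_absolute_error(select_on(preds, "y"), select_on(target, "y"))
mae_plain = mean_absolute_error(y_pred, y_true)
print(mae_on == mae_plain)
```

In this sketch the dict version and the plain-array version compute exactly the same number, which is the behaviour the issue expects from metrics as well as losses.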

Minimal code to reproduce

import elegy as eg
import optax
import numpy as np


def data_generator():
    while True:
        yield (
            np.random.randn(10, 1),
            {"target": {"y": np.random.randn(10, 1)}},
        )


class MyModule(eg.Module):
    @eg.compact
    def __call__(self, x):
        return {"y": eg.nn.Linear(1)(x)}


model = eg.Model(
    MyModule(),
    loss=eg.losses.MeanSquaredError(on="y"),
    metrics=eg.metrics.MeanAbsoluteError(on="y"),  #  <-- works fine without this line
    optimizer=optax.adam(1e-3),
)

hist = model.fit(
    data_generator(),
    steps_per_epoch=10,
    epochs=10,
)

Stack trace:

Traceback (most recent call last):
  File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/scott/Documents/phd/geom/pde/metric.py", line 27, in <module>
    hist = model.fit(
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 417, in fit
    tmp_logs = self.train_on_batch(
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 617, in train_on_batch
    logs, model = train_step_fn(self, inputs, labels)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 412, in _static_train_step
    return model.train_step(inputs, labels)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 306, in train_step
    grads, (logs, model) = grad_fn(params, model, inputs, labels)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 278, in loss_fn
    loss, logs, model = model.test_step(inputs, labels)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 248, in test_step
    batch_loss_and_logs.update(
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/loss_and_logs.py", line 78, in update
    self.metrics.update(**metrics_kwargs)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/metrics.py", line 44, in update
    metric.update(**metric_kwargs)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 83, in update
    values = _mean_absolute_error(preds, target)
  File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 20, in _mean_absolute_error
    target = target.astype(preds.dtype)
AttributeError: 'dict' object has no attribute 'astype'

Expected behavior
It should produce the same result as the equivalent code with the dictionaries removed and the on argument not specified.

Library Info
python version: 3.8.13
elegy version: 0.8.6
treex version: 0.6.10

Additional context
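From my digging, the cause seems to be that Metric.update() is called instead of Metric.__call__(), which appears to be where the on argument is applied. The traceback shows treex's loss_and_logs.py going through Metrics.update() straight to MeanAbsoluteError.update(), which then receives the raw dicts.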
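The suspected call path can be modelled with a toy class. This is a simplified, hypothetical sketch of the pattern described above, assuming the on selection lives only in __call__ and that update() never looks at on; ToyMeanAbsoluteError is not treex's actual Metric implementation. Calling update() directly, as the traceback shows happening, passes the raw dicts through and fails with the same AttributeError.

```python
import numpy as np


class ToyMeanAbsoluteError:
    """Hypothetical stand-in for a metric that accepts an `on` argument."""

    def __init__(self, on=None):
        self.on = on
        self.total = 0.0
        self.count = 0

    def update(self, preds, target):
        # Assumes arrays, like the real traceback: a dict target fails on .astype.
        target = target.astype(preds.dtype)
        self.total += float(np.sum(np.abs(preds - target)))
        self.count += preds.size

    def compute(self):
        return self.total / self.count

    def __call__(self, preds, target):
        # Assumption: `on` is applied only here, before delegating to update().
        if self.on is not None:
            preds, target = preds[self.on], target[self.on]
        self.update(preds=preds, target=target)
        return self.compute()


preds = {"y": np.array([[1.0], [2.0]])}
target = {"y": np.array([[0.0], [4.0]])}

# Going through __call__ applies `on`: (|1 - 0| + |2 - 4|) / 2 = 1.5
print(ToyMeanAbsoluteError(on="y")(preds, target))

# Calling update() directly bypasses `on` and reproduces the reported error.
try:
    ToyMeanAbsoluteError(on="y").update(preds=preds, target=target)
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'astype'
```

In this toy model, either routing the call through __call__ or applying the on selection before update() avoids the error; which of these fits elegy and treex is a design decision for the maintainers.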

ScottAlexanderCameron added the "bug" label on Sep 28, 2022.