Skip to content

Commit 603818f

Browse files
committed
Fix failing test
1 parent 4851084 commit 603818f

File tree

2 files changed

+3
-9
lines changed

2 files changed

+3
-9
lines changed

lib/axon/loop.ex

+2-8
Original file line numberDiff line numberDiff line change
@@ -367,14 +367,8 @@ defmodule Axon.Loop do
367367
end)
368368

369369
model_out = forward_model_fn.(model_state, inp)
370-
371-
{scaled_loss, unscaled_loss} =
372-
tar
373-
|> loss_fn.(model_out.prediction)
374-
|> then(fn loss ->
375-
scaled = scale_loss.(loss, loss_scale_state)
376-
{scaled, loss}
377-
end)
370+
unscaled_loss = loss_fn.(tar, model_out.prediction)
371+
scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)
378372

379373
{model_out, scaled_loss, unscaled_loss}
380374
end

test/axon/loop_test.exs

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ defmodule Axon.LoopTest do
6262

6363
test "trainer/3 returns a supervised training loop with custom loss" do
6464
model = Axon.input("input", shape: {nil, 1})
65-
custom_loss_fn = fn _, _ -> Nx.tensor(5.0, backend: Nx.BinaryBackend) end
65+
custom_loss_fn = fn _, _ -> Nx.tensor(5.0, backend: Nx.Defn.Expr) end
6666

6767
assert %Loop{init: init_fn, step: update_fn, output_transform: transform} =
6868
Loop.trainer(model, custom_loss_fn, :adam)

0 commit comments

Comments
 (0)