File tree 2 files changed +3
-9
lines changed
2 files changed +3
-9
lines changed Original file line number Diff line number Diff line change @@ -367,14 +367,8 @@ defmodule Axon.Loop do
367
367
end )
368
368
369
369
model_out = forward_model_fn . ( model_state , inp )
370
-
371
- { scaled_loss , unscaled_loss } =
372
- tar
373
- |> loss_fn . ( model_out . prediction )
374
- |> then ( fn loss ->
375
- scaled = scale_loss . ( loss , loss_scale_state )
376
- { scaled , loss }
377
- end )
370
+ unscaled_loss = loss_fn . ( tar , model_out . prediction )
371
+ scaled_loss = scale_loss . ( unscaled_loss , loss_scale_state )
378
372
379
373
{ model_out , scaled_loss , unscaled_loss }
380
374
end
Original file line number Diff line number Diff line change @@ -62,7 +62,7 @@ defmodule Axon.LoopTest do
62
62
63
63
test "trainer/3 returns a supervised training loop with custom loss" do
64
64
model = Axon . input ( "input" , shape: { nil , 1 } )
65
- custom_loss_fn = fn _ , _ -> Nx . tensor ( 5.0 , backend: Nx.BinaryBackend ) end
65
+ custom_loss_fn = fn _ , _ -> Nx . tensor ( 5.0 , backend: Nx.Defn.Expr ) end
66
66
67
67
assert % Loop { init: init_fn , step: update_fn , output_transform: transform } =
68
68
Loop . trainer ( model , custom_loss_fn , :adam )
You can’t perform that action at this time.
0 commit comments