@@ -325,7 +325,6 @@ defmodule Axon.Loop do
325
325
"""
326
326
def train_step ( model , loss , optimizer , opts \\ [ ] ) do
327
327
opts = Keyword . validate! ( opts , [ :seed , loss_scale: :identity ] )
328
-
329
328
loss_scale = opts [ :loss_scale ] || :identity
330
329
331
330
{ init_model_fn , forward_model_fn } = build_model_fns ( model , :train , opts )
@@ -341,8 +340,7 @@ defmodule Axon.Loop do
341
340
optimizer_state = init_optimizer_fn . ( trainable_parameters )
342
341
loss_scale_state = init_loss_scale . ( )
343
342
344
- # TODO: is this expensive? Will it compute the entire
345
- # forward?
343
+ # TODO: is this expensive? Will it compute the entire forward?
346
344
% { prediction: output } = forward_model_fn . ( model_state , inp )
347
345
348
346
% {
@@ -507,6 +505,7 @@ defmodule Axon.Loop do
507
505
raise_bad_training_inputs! ( data , state )
508
506
end
509
507
508
+ # Pass on_conflict: :reuse as we want someone to jit it on top
510
509
{
511
510
Nx.Defn . jit ( init_fn , on_conflict: :reuse ) ,
512
511
Nx.Defn . jit ( step_fn , on_conflict: :reuse )
@@ -1563,9 +1562,9 @@ defmodule Axon.Loop do
1563
1562
* `:debug` - run loop in debug mode to trace loop progress. Defaults to
1564
1563
false.
1565
1564
1566
- Additional options are forwarded to `Nx.Defn.jit` as JIT-options. If no JIT
1567
- options are set, the default options set with `Nx.Defn.default_options` are
1568
- used.
1565
+ Additional options are forwarded to `Nx.Defn.jit` as JIT-options. If no JIT
1566
+ options are set, the default options set with `Nx.Defn.default_options` are
1567
+ used.
1569
1568
"""
1570
1569
def run ( loop , data , init_state \\ % { } , opts \\ [ ] ) do
1571
1570
{ max_epochs , opts } = Keyword . pop ( opts , :epochs , 1 )
@@ -2263,6 +2262,10 @@ defmodule Axon.Loop do
2263
2262
# otherwise just applies the function with the given arguments
2264
2263
defp maybe_jit ( fun , args , jit_compile? , jit_opts ) do
2265
2264
if jit_compile? do
2265
+ # If there is a disk cache, we only want it to apply to the batch function
2266
+ jit_opts =
2267
+ if is_binary ( jit_opts [ :cache ] ) , do: Keyword . delete ( jit_opts , :cache ) , else: jit_opts
2268
+
2266
2269
apply ( Nx.Defn . jit ( fun , jit_opts ) , args )
2267
2270
else
2268
2271
apply ( fun , args )
0 commit comments