Skip to content

Commit 7196ddb

Browse files
committed
Discard disk cache for step function
1 parent f90e5ac commit 7196ddb

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

lib/axon/loop.ex

+9-6
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,6 @@ defmodule Axon.Loop do
325325
"""
326326
def train_step(model, loss, optimizer, opts \\ []) do
327327
opts = Keyword.validate!(opts, [:seed, loss_scale: :identity])
328-
329328
loss_scale = opts[:loss_scale] || :identity
330329

331330
{init_model_fn, forward_model_fn} = build_model_fns(model, :train, opts)
@@ -341,8 +340,7 @@ defmodule Axon.Loop do
341340
optimizer_state = init_optimizer_fn.(trainable_parameters)
342341
loss_scale_state = init_loss_scale.()
343342

344-
# TODO: is this expensive? Will it compute the entire
345-
# forward?
343+
# TODO: is this expensive? Will it compute the entire forward?
346344
%{prediction: output} = forward_model_fn.(model_state, inp)
347345

348346
%{
@@ -507,6 +505,7 @@ defmodule Axon.Loop do
507505
raise_bad_training_inputs!(data, state)
508506
end
509507

508+
# Pass on_conflict: :reuse as we want someone to jit it on top
510509
{
511510
Nx.Defn.jit(init_fn, on_conflict: :reuse),
512511
Nx.Defn.jit(step_fn, on_conflict: :reuse)
@@ -1563,9 +1562,9 @@ defmodule Axon.Loop do
15631562
* `:debug` - run loop in debug mode to trace loop progress. Defaults to
15641563
false.
15651564
1566-
Additional options are forwarded to `Nx.Defn.jit` as JIT-options. If no JIT
1567-
options are set, the default options set with `Nx.Defn.default_options` are
1568-
used.
1565+
Additional options are forwarded to `Nx.Defn.jit` as JIT-options. If no JIT
1566+
options are set, the default options set with `Nx.Defn.default_options` are
1567+
used.
15691568
"""
15701569
def run(loop, data, init_state \\ %{}, opts \\ []) do
15711570
{max_epochs, opts} = Keyword.pop(opts, :epochs, 1)
@@ -2263,6 +2262,10 @@ defmodule Axon.Loop do
22632262
# otherwise just applies the function with the given arguments
22642263
defp maybe_jit(fun, args, jit_compile?, jit_opts) do
22652264
if jit_compile? do
2265+
# If there is a disk cache, we only want it to apply to the batch function
2266+
jit_opts =
2267+
if is_binary(jit_opts[:cache]), do: Keyword.delete(jit_opts, :cache), else: jit_opts
2268+
22662269
apply(Nx.Defn.jit(fun, jit_opts), args)
22672270
else
22682271
apply(fun, args)

0 commit comments

Comments
 (0)