Skip to content

Commit e19edf1

Browse files
committed
Discard cache on init_params computation
1 parent 7196ddb commit e19edf1

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

lib/axon.ex

+13-5
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ defmodule Axon do
587587
end
588588

589589
@doc """
590-
Implements an or else (e.g. an Elixir ||)
590+
Implements an or else (e.g. an Elixir ||)
591591
"""
592592
@doc type: :special
593593
def or_else(%Axon{} = a, %Axon{} = b, opts \\ []) do
@@ -3771,7 +3771,7 @@ defmodule Axon do
37713771
as input and returns a function that replaces or rewrites the given node.
37723772
For example, you can define a simple rewriter which replaces the `:relu`
37733773
layers with `:tanh` layers:
3774-
3774+
37753775
tanh_rewriter = fn [%Axon{} = x], _output ->
37763776
Axon.relu(x)
37773777
end
@@ -3926,13 +3926,16 @@ defmodule Axon do
39263926
end
39273927

39283928
@doc """
3929-
Compiles the given model to `{init_fn, predict_fn}`.
3929+
Compiles the given model to `{init_params, predict_fn}`.
39303930
39313931
This function will compile a model specialized to the given
39323932
input shapes and types. This is useful for avoiding the overhead
39333933
of long compilations at program runtime. You must provide template
39343934
inputs which match the expected shapes and types of inputs at
3935-
execution time.
3935+
execution time. Depending on the Nx compiler, such as EXLA v0.9.1+,
3936+
both `init_params` the `predict_fn` can be sent across nodes, as
3937+
long the node that owns them keeps a reference to the underlying
3938+
resources.
39363939
39373940
This function makes use of the built-in `Nx.Defn.compile/3`. Note
39383941
that passing inputs which differ in shape or type from the templates
@@ -3946,7 +3949,12 @@ defmodule Axon do
39463949
def compile(model, template, init_params \\ Axon.ModelState.empty(), opts \\ [])
39473950
when is_list(opts) do
39483951
{init_fn, predict_fn} = build(model, opts)
3949-
init_params = Nx.Defn.jit_apply(init_fn, [template, Axon.ModelState.new(init_params)], opts)
3952+
model_state = Axon.ModelState.new(init_params)
3953+
3954+
# If there is a disk cache, we only want it to apply to the predict function
3955+
init_opts = if is_binary(opts[:cache]), do: Keyword.delete(opts, :cache), else: opts
3956+
init_params = Nx.Defn.jit_apply(init_fn, [template, model_state], init_opts)
3957+
39503958
predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template], opts)
39513959
{init_params, predict_compiled_fn}
39523960
end

0 commit comments

Comments
 (0)