@@ -587,7 +587,7 @@ defmodule Axon do
587
587
end
588
588
589
589
@ doc """
590
- Implements an or else (e.g. an Elixir ||)
590
+ Implements an or else (e.g. an Elixir ||)
591
591
"""
592
592
@ doc type: :special
593
593
def or_else ( % Axon { } = a , % Axon { } = b , opts \\ [ ] ) do
@@ -3771,7 +3771,7 @@ defmodule Axon do
3771
3771
as input and returns a function that replaces or rewrites the given node.
3772
3772
For example, you can define a simple rewriter which replaces the `:relu`
3773
3773
layers with `:tanh` layers:
3774
-
3774
+
3775
3775
tanh_rewriter = fn [%Axon{} = x], _output ->
3776
3776
Axon.relu(x)
3777
3777
end
@@ -3926,13 +3926,16 @@ defmodule Axon do
3926
3926
end
3927
3927
3928
3928
@ doc """
3929
- Compiles the given model to `{init_fn , predict_fn}`.
3929
+ Compiles the given model to `{init_params , predict_fn}`.
3930
3930
3931
3931
This function will compile a model specialized to the given
3932
3932
input shapes and types. This is useful for avoiding the overhead
3933
3933
of long compilations at program runtime. You must provide template
3934
3934
inputs which match the expected shapes and types of inputs at
3935
- execution time.
3935
+ execution time. Depending on the Nx compiler, such as EXLA v0.9.1+,
3936
+ both `init_params` the `predict_fn` can be sent across nodes, as
3937
+ long the node that owns them keeps a reference to the underlying
3938
+ resources.
3936
3939
3937
3940
This function makes use of the built-in `Nx.Defn.compile/3`. Note
3938
3941
that passing inputs which differ in shape or type from the templates
@@ -3946,7 +3949,12 @@ defmodule Axon do
3946
3949
def compile ( model , template , init_params \\ Axon.ModelState . empty ( ) , opts \\ [ ] )
3947
3950
when is_list ( opts ) do
3948
3951
{ init_fn , predict_fn } = build ( model , opts )
3949
- init_params = Nx.Defn . jit_apply ( init_fn , [ template , Axon.ModelState . new ( init_params ) ] , opts )
3952
+ model_state = Axon.ModelState . new ( init_params )
3953
+
3954
+ # If there is a disk cache, we only want it to apply to the predict function
3955
+ init_opts = if is_binary ( opts [ :cache ] ) , do: Keyword . delete ( opts , :cache ) , else: opts
3956
+ init_params = Nx.Defn . jit_apply ( init_fn , [ template , model_state ] , init_opts )
3957
+
3950
3958
predict_compiled_fn = Nx.Defn . compile ( predict_fn , [ init_params , template ] , opts )
3951
3959
{ init_params , predict_compiled_fn }
3952
3960
end
0 commit comments