@@ -3943,9 +3943,10 @@ defmodule Axon do
3943
3943
It accepts the same options as `build/2`.
3944
3944
"""
3945
3945
@ doc type: :model
3946
- def compile ( model , template , init_params \\ % { } , opts \\ [ ] ) when is_list ( opts ) do
3946
+ def compile ( model , template , init_params \\ Axon.ModelState . empty ( ) , opts \\ [ ] )
3947
+ when is_list ( opts ) do
3947
3948
{ init_fn , predict_fn } = build ( model , opts )
3948
- init_params = Nx.Defn . jit_apply ( init_fn , [ template , init_params ] , opts )
3949
+ init_params = Nx.Defn . jit_apply ( init_fn , [ template , Axon.ModelState . new ( init_params ) ] , opts )
3949
3950
predict_compiled_fn = Nx.Defn . compile ( predict_fn , [ init_params , template ] , opts )
3950
3951
{ init_params , predict_compiled_fn }
3951
3952
end
@@ -3976,7 +3977,7 @@ defmodule Axon do
3976
3977
@ doc type: :debug
3977
3978
def trace_init ( model , template , params \\ Axon.ModelState . empty ( ) , opts \\ [ ] ) do
3978
3979
{ init_fn , _ } = build ( model , opts )
3979
- Nx.Defn . jit ( init_fn , compiler: Axon.Defn ) . ( template , params )
3980
+ Nx.Defn . jit ( init_fn , compiler: Axon.Defn ) . ( template , Axon.ModelState . new ( params ) )
3980
3981
end
3981
3982
3982
3983
@ doc """
@@ -4001,7 +4002,7 @@ defmodule Axon do
4001
4002
@ doc type: :debug
4002
4003
def trace_forward ( model , inputs , params , opts \\ [ ] ) when is_list ( opts ) do
4003
4004
{ _ , forward_fun } = build ( model , opts )
4004
- Nx.Defn . jit ( forward_fun , compiler: Axon.Defn ) . ( params , inputs )
4005
+ Nx.Defn . jit ( forward_fun , compiler: Axon.Defn ) . ( Axon.ModelState . new ( params ) , inputs )
4005
4006
end
4006
4007
4007
4008
@ doc """
@@ -4034,17 +4035,19 @@ defmodule Axon do
4034
4035
end )
4035
4036
end
4036
4037
4037
- % { prediction: outputs } = Nx.Defn . jit ( forward_fn , compiler: Axon.Defn ) . ( params , inputs )
4038
+ % { prediction: outputs } =
4039
+ Nx.Defn . jit ( forward_fn , compiler: Axon.Defn ) . ( Axon.ModelState . new ( params ) , inputs )
4040
+
4038
4041
inputs = [ params , inputs , outputs ]
4039
4042
4040
4043
apply ( Nx.Defn . jit ( backward_fn , compiler: Axon.Defn ) , inputs )
4041
4044
end
4042
4045
4043
4046
@ doc false
4044
4047
@ deprecated "Use Axon.build/2 instead"
4045
- def init ( model , template , params \\ % { } , opts \\ [ ] ) when is_list ( opts ) do
4048
+ def init ( model , template , params \\ Axon.ModelState . empty ( ) , opts \\ [ ] ) when is_list ( opts ) do
4046
4049
{ init_fn , _predict_fn } = build ( model , opts )
4047
- init_fn . ( template , params )
4050
+ init_fn . ( template , Axon.ModelState . new ( params ) )
4048
4051
end
4049
4052
4050
4053
@ doc """
@@ -4069,7 +4072,7 @@ defmodule Axon do
4069
4072
@ doc type: :model
4070
4073
def predict ( % Axon { } = model , params , input , opts \\ [ ] ) when is_list ( opts ) do
4071
4074
{ _init_fn , predict_fn } = build ( model , opts )
4072
- predict_fn . ( params , input )
4075
+ predict_fn . ( Axon.ModelState . new ( params ) , input )
4073
4076
end
4074
4077
4075
4078
## Inspection
0 commit comments