Skip to content

Commit 054eb4c

Browse files
committed
Use model state everywhere as default
1 parent 8e0a6d9 commit 054eb4c

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

lib/axon.ex

+11-8
Original file line numberDiff line numberDiff line change
@@ -3943,9 +3943,10 @@ defmodule Axon do
39433943
It accepts the same options as `build/2`.
39443944
"""
39453945
@doc type: :model
3946-
def compile(model, template, init_params \\ %{}, opts \\ []) when is_list(opts) do
3946+
def compile(model, template, init_params \\ Axon.ModelState.empty(), opts \\ [])
3947+
when is_list(opts) do
39473948
{init_fn, predict_fn} = build(model, opts)
3948-
init_params = Nx.Defn.jit_apply(init_fn, [template, init_params], opts)
3949+
init_params = Nx.Defn.jit_apply(init_fn, [template, Axon.ModelState.new(init_params)], opts)
39493950
predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template], opts)
39503951
{init_params, predict_compiled_fn}
39513952
end
@@ -3976,7 +3977,7 @@ defmodule Axon do
39763977
@doc type: :debug
39773978
def trace_init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) do
39783979
{init_fn, _} = build(model, opts)
3979-
Nx.Defn.jit(init_fn, compiler: Axon.Defn).(template, params)
3980+
Nx.Defn.jit(init_fn, compiler: Axon.Defn).(template, Axon.ModelState.new(params))
39803981
end
39813982

39823983
@doc """
@@ -4001,7 +4002,7 @@ defmodule Axon do
40014002
@doc type: :debug
40024003
def trace_forward(model, inputs, params, opts \\ []) when is_list(opts) do
40034004
{_, forward_fun} = build(model, opts)
4004-
Nx.Defn.jit(forward_fun, compiler: Axon.Defn).(params, inputs)
4005+
Nx.Defn.jit(forward_fun, compiler: Axon.Defn).(Axon.ModelState.new(params), inputs)
40054006
end
40064007

40074008
@doc """
@@ -4034,17 +4035,19 @@ defmodule Axon do
40344035
end)
40354036
end
40364037

4037-
%{prediction: outputs} = Nx.Defn.jit(forward_fn, compiler: Axon.Defn).(params, inputs)
4038+
%{prediction: outputs} =
4039+
Nx.Defn.jit(forward_fn, compiler: Axon.Defn).(Axon.ModelState.new(params), inputs)
4040+
40384041
inputs = [params, inputs, outputs]
40394042

40404043
apply(Nx.Defn.jit(backward_fn, compiler: Axon.Defn), inputs)
40414044
end
40424045

40434046
@doc false
40444047
@deprecated "Use Axon.build/2 instead"
4045-
def init(model, template, params \\ %{}, opts \\ []) when is_list(opts) do
4048+
def init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) when is_list(opts) do
40464049
{init_fn, _predict_fn} = build(model, opts)
4047-
init_fn.(template, params)
4050+
init_fn.(template, Axon.ModelState.new(params))
40484051
end
40494052

40504053
@doc """
@@ -4069,7 +4072,7 @@ defmodule Axon do
40694072
@doc type: :model
40704073
def predict(%Axon{} = model, params, input, opts \\ []) when is_list(opts) do
40714074
{_init_fn, predict_fn} = build(model, opts)
4072-
predict_fn.(params, input)
4075+
predict_fn.(Axon.ModelState.new(params), input)
40734076
end
40744077

40754078
## Inspection

0 commit comments

Comments
 (0)