@@ -352,7 +352,7 @@ defmodule Axon.Compiler do
352
352
end
353
353
354
354
defp recur_model_funs (
355
- % Axon.Node { id: id , op: :constant , opts: [ value: tensor ] , policy: % { output: output } } ,
355
+ % Axon.Node { id: id , op: :constant , opts: [ value: tensor ] , policy: policy } ,
356
356
_nodes ,
357
357
{ cache , op_counts , block_cache } ,
358
358
_
@@ -361,7 +361,7 @@ defmodule Axon.Compiler do
361
361
tensor = Nx . backend_copy ( tensor , Nx.BinaryBackend )
362
362
363
363
predict_fun = fn _params , _inputs , state , _cache , result_cache , _fn_stacktrace ->
364
- out = safe_as_type ( tensor , output )
364
+ out = safe_policy_cast ( tensor , policy , : output)
365
365
{ out , { state , result_cache } }
366
366
end
367
367
@@ -841,7 +841,7 @@ defmodule Axon.Compiler do
841
841
name ,
842
842
args ,
843
843
opts ,
844
- % { output: output , compute: compute } ,
844
+ policy ,
845
845
layer_params ,
846
846
hooks ,
847
847
mode ,
@@ -870,7 +870,7 @@ defmodule Axon.Compiler do
870
870
871
871
layer_input =
872
872
layer_input
873
- |> safe_as_type ( compute )
873
+ |> safe_policy_cast ( policy , : compute)
874
874
|> apply_hooks ( :pre_forward , mode , hooks )
875
875
876
876
{ layer_input , { state , result_cache , none? } }
@@ -889,7 +889,7 @@ defmodule Axon.Compiler do
889
889
890
890
cond do
891
891
param != nil ->
892
- safe_as_type ( maybe_freeze ( param , frz ) , compute )
892
+ safe_policy_cast ( maybe_freeze ( param , frz ) , policy , : compute)
893
893
894
894
true ->
895
895
raise ArgumentError ,
@@ -939,7 +939,7 @@ defmodule Axon.Compiler do
939
939
out
940
940
|> apply_hooks ( :forward , mode , hooks )
941
941
|> apply_hooks ( :backward , mode , hooks )
942
- |> safe_as_type ( output )
942
+ |> safe_policy_cast ( policy , : output)
943
943
944
944
new_state = Map . put ( state , name , out_state )
945
945
{ new_out , new_state }
@@ -949,7 +949,7 @@ defmodule Axon.Compiler do
949
949
out
950
950
|> apply_hooks ( :forward , mode , hooks )
951
951
|> apply_hooks ( :backward , mode , hooks )
952
- |> safe_as_type ( output )
952
+ |> safe_policy_cast ( policy , : output)
953
953
954
954
{ new_out , state }
955
955
end
@@ -1130,26 +1130,13 @@ defmodule Axon.Compiler do
1130
1130
end )
1131
1131
end
1132
1132
1133
- defp safe_as_type ( container_or_tensor , type ) do
1133
+ defp safe_policy_cast ( container_or_tensor , policy , variable_type ) do
1134
1134
case container_or_tensor do
1135
1135
% Axon.None { } = none ->
1136
1136
none
1137
1137
1138
- % Nx.Tensor { } = tensor ->
1139
- if not Nx.Type . integer? ( Nx . type ( tensor ) ) and not Nx.Type . integer? ( type ) do
1140
- Nx . as_type ( tensor , type )
1141
- else
1142
- tensor
1143
- end
1144
-
1145
- container ->
1146
- deep_new ( container , fn tensor ->
1147
- if not Nx.Type . integer? ( Nx . type ( tensor ) ) and not Nx.Type . integer? ( type ) do
1148
- Nx . as_type ( tensor , type )
1149
- else
1150
- tensor
1151
- end
1152
- end )
1138
+ container_or_tensor ->
1139
+ Axon.MixedPrecision . cast ( policy , container_or_tensor , variable_type )
1153
1140
end
1154
1141
end
1155
1142
0 commit comments