Skip to content

Commit acdc002

Browse files
Do not cast integers in Axon.MixedPrecision.cast/2 (#562)
1 parent 2a2d165 commit acdc002

File tree

2 files changed

+28
-25
lines changed

2 files changed

+28
-25
lines changed

Diff for: lib/axon/compiler.ex

+10-23
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ defmodule Axon.Compiler do
352352
end
353353

354354
defp recur_model_funs(
355-
%Axon.Node{id: id, op: :constant, opts: [value: tensor], policy: %{output: output}},
355+
%Axon.Node{id: id, op: :constant, opts: [value: tensor], policy: policy},
356356
_nodes,
357357
{cache, op_counts, block_cache},
358358
_
@@ -361,7 +361,7 @@ defmodule Axon.Compiler do
361361
tensor = Nx.backend_copy(tensor, Nx.BinaryBackend)
362362

363363
predict_fun = fn _params, _inputs, state, _cache, result_cache, _fn_stacktrace ->
364-
out = safe_as_type(tensor, output)
364+
out = safe_policy_cast(tensor, policy, :output)
365365
{out, {state, result_cache}}
366366
end
367367

@@ -841,7 +841,7 @@ defmodule Axon.Compiler do
841841
name,
842842
args,
843843
opts,
844-
%{output: output, compute: compute},
844+
policy,
845845
layer_params,
846846
hooks,
847847
mode,
@@ -870,7 +870,7 @@ defmodule Axon.Compiler do
870870

871871
layer_input =
872872
layer_input
873-
|> safe_as_type(compute)
873+
|> safe_policy_cast(policy, :compute)
874874
|> apply_hooks(:pre_forward, mode, hooks)
875875

876876
{layer_input, {state, result_cache, none?}}
@@ -889,7 +889,7 @@ defmodule Axon.Compiler do
889889

890890
cond do
891891
param != nil ->
892-
safe_as_type(maybe_freeze(param, frz), compute)
892+
safe_policy_cast(maybe_freeze(param, frz), policy, :compute)
893893

894894
true ->
895895
raise ArgumentError,
@@ -939,7 +939,7 @@ defmodule Axon.Compiler do
939939
out
940940
|> apply_hooks(:forward, mode, hooks)
941941
|> apply_hooks(:backward, mode, hooks)
942-
|> safe_as_type(output)
942+
|> safe_policy_cast(policy, :output)
943943

944944
new_state = Map.put(state, name, out_state)
945945
{new_out, new_state}
@@ -949,7 +949,7 @@ defmodule Axon.Compiler do
949949
out
950950
|> apply_hooks(:forward, mode, hooks)
951951
|> apply_hooks(:backward, mode, hooks)
952-
|> safe_as_type(output)
952+
|> safe_policy_cast(policy, :output)
953953

954954
{new_out, state}
955955
end
@@ -1130,26 +1130,13 @@ defmodule Axon.Compiler do
11301130
end)
11311131
end
11321132

1133-
defp safe_as_type(container_or_tensor, type) do
1133+
defp safe_policy_cast(container_or_tensor, policy, variable_type) do
11341134
case container_or_tensor do
11351135
%Axon.None{} = none ->
11361136
none
11371137

1138-
%Nx.Tensor{} = tensor ->
1139-
if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
1140-
Nx.as_type(tensor, type)
1141-
else
1142-
tensor
1143-
end
1144-
1145-
container ->
1146-
deep_new(container, fn tensor ->
1147-
if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
1148-
Nx.as_type(tensor, type)
1149-
else
1150-
tensor
1151-
end
1152-
end)
1138+
container_or_tensor ->
1139+
Axon.MixedPrecision.cast(policy, container_or_tensor, variable_type)
11531140
end
11541141
end
11551142

Diff for: lib/axon/mixed_precision.ex

+18-2
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,26 @@ defmodule Axon.MixedPrecision do
146146
iex> value = Axon.MixedPrecision.cast(policy, value, :output)
147147
iex> Nx.type(value)
148148
{:bf, 16}
149+
150+
Note that integers are never promoted to floats:
151+
152+
iex> policy = Axon.MixedPrecision.create_policy(output: {:f, 16})
153+
iex> value = Nx.tensor([1, 2, 3], type: :s64)
154+
iex> value = Axon.MixedPrecision.cast(policy, value, :params)
155+
iex> Nx.type(value)
156+
{:s, 64}
157+
149158
"""
150159
def cast(%Policy{} = policy, tensor_or_container, variable_type)
151160
when variable_type in [:compute, :params, :output] do
152-
type = get_in(policy, [Access.key!(variable_type)])
153-
deep_new(tensor_or_container, fn x -> Nx.as_type(x, type) end)
161+
type = Map.fetch!(policy, variable_type)
162+
163+
deep_new(tensor_or_container, fn tensor ->
164+
if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
165+
Nx.as_type(tensor, type)
166+
else
167+
tensor
168+
end
169+
end)
154170
end
155171
end

0 commit comments

Comments
 (0)