Skip to content

Commit 19803a0

Browse files
committed
Add option to not raise if output is none, resolves #538
1 parent 252da8c commit 19803a0

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

lib/axon.ex

+1-1
Original file line numberDiff line numberDiff line change
@@ -3405,7 +3405,7 @@ defmodule Axon do
34053405
"""
34063406
@doc type: :graph
34073407
def get_output_shape(%Axon{} = axon, inputs, opts \\ []) do
3408-
{init_fn, forward_fn} = build(axon, opts)
3408+
{init_fn, forward_fn} = build(axon, opts ++ [raise_on_none: false])
34093409

34103410
out =
34113411
Nx.Defn.jit(

lib/axon/compiler.ex

+7-4
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ defmodule Axon.Compiler do
4848
@doc false
4949
def build(%Axon{output: id, nodes: nodes}, opts) do
5050
debug? = Keyword.get(opts, :debug, false)
51+
raise_on_none? = Keyword.get(opts, :raise_on_none, true)
5152
mode = Keyword.get(opts, :mode, :inference)
5253
seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
5354
global_layer_options = Keyword.get(opts, :global_layer_options, [])
@@ -105,10 +106,12 @@ defmodule Axon.Compiler do
105106
end
106107

107108
with %Axon.None{} <- result do
108-
raise ArgumentError,
109-
"the compiled model will always result in %Axon.None{}." <>
110-
" This most likely means you specified optional output and " <>
111-
" did not handle the case when it is missing"
109+
if raise_on_none? do
110+
raise ArgumentError,
111+
"the compiled model will always result in %Axon.None{}." <>
112+
" This most likely means you specified optional output and " <>
113+
" did not handle the case when it is missing"
114+
end
112115
end
113116

114117
result

test/axon_test.exs

+14
Original file line numberDiff line numberDiff line change
@@ -1076,5 +1076,19 @@ defmodule AxonTest do
10761076
assert shape = Axon.get_output_shape(model, Nx.template({1, 1}, :f32))
10771077
assert shape == {{1, 2}, {1, 2}}
10781078
end
1079+
1080+
test "doesn't raise on none output" do
1081+
values = Axon.input("values")
1082+
mask = Axon.input("mask", optional: true)
1083+
1084+
model =
1085+
values
1086+
|> Axon.dense(10)
1087+
|> Axon.multiply(mask)
1088+
|> Axon.dense(1)
1089+
|> Axon.sigmoid()
1090+
1091+
assert %Axon.None{} = Axon.get_output_shape(model, %{"values" => Nx.template({1, 1}, :f32)})
1092+
end
10791093
end
10801094
end

0 commit comments

Comments
 (0)