Skip to content

Commit 7e0e593

Browse files
authored Mar 6, 2024
Add support for global layer options (#563)
1 parent acdc002 commit 7e0e593

File tree

4 files changed

+50
-5
lines changed

4 files changed

+50
-5
lines changed
 

‎lib/axon.ex

+12-2
Original file line number | Diff line number | Diff line change
@@ -301,9 +301,14 @@ defmodule Axon do
301301
to inference function except:
302302
303303
* `:name` - layer name.
304+
304305
* `:op_name` - layer operation for inspection and building parameter map.
306+
305307
* `:mode` - if the layer should run only on `:inference` or `:train`. Defaults to `:both`
306308
309+
* `:global_options` - a list of global option names that this layer
310+
supports. Global options passed to `build/2` will be forwarded to
311+
the layer, as long as they are declared
307312
308313
Note this means your layer should not use these as input options,
309314
as they will always be dropped during inference compilation.
@@ -332,14 +337,15 @@ defmodule Axon do
332337
{mode, opts} = Keyword.pop(opts, :mode, :both)
333338
{name, opts} = Keyword.pop(opts, :name)
334339
{op_name, opts} = Keyword.pop(opts, :op_name, :custom)
340+
{global_options, opts} = Keyword.pop(opts, :global_options, [])
335341
name = name(op_name, name)
336342

337343
id = System.unique_integer([:positive, :monotonic])
338-
axon_node = make_node(id, op, name, op_name, mode, inputs, params, args, opts)
344+
axon_node = make_node(id, op, name, op_name, mode, inputs, params, args, opts, global_options)
339345
%Axon{output: id, nodes: Map.put(updated_nodes, id, axon_node)}
340346
end
341347

342-
defp make_node(id, op, name, op_name, mode, inputs, params, args, layer_opts) do
348+
defp make_node(id, op, name, op_name, mode, inputs, params, args, layer_opts, global_options) do
343349
{:current_stacktrace, [_process_info, _axon_layer | stacktrace]} =
344350
Process.info(self(), :current_stacktrace)
345351

@@ -354,6 +360,7 @@ defmodule Axon do
354360
policy: Axon.MixedPrecision.create_policy(),
355361
hooks: [],
356362
opts: layer_opts,
363+
global_options: global_options,
357364
op_name: op_name,
358365
stacktrace: stacktrace
359366
}
@@ -3651,6 +3658,9 @@ defmodule Axon do
36513658
to control differences in compilation at training or inference time.
36523659
Defaults to `:inference`
36533660
3661+
* `:global_layer_options` - a keyword list of options passed to
3662+
layers that accept said options
3663+
36543664
All other options are forwarded to the underlying JIT compiler.
36553665
"""
36563666
@doc type: :model

‎lib/axon/compiler.ex

+14-3
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,8 @@ defmodule Axon.Compiler do
5050
debug? = Keyword.get(opts, :debug, false)
5151
mode = Keyword.get(opts, :mode, :inference)
5252
seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
53-
config = %{mode: mode, debug?: debug?}
53+
global_layer_options = Keyword.get(opts, :global_layer_options, [])
54+
config = %{mode: mode, debug?: debug?, global_layer_options: global_layer_options}
5455

5556
{time, {root_id, {cache, _op_counts, _block_cache}}} =
5657
:timer.tc(fn ->
@@ -718,14 +719,15 @@ defmodule Axon.Compiler do
718719
parameters: layer_params,
719720
args: args,
720721
opts: opts,
722+
global_options: global_options,
721723
policy: policy,
722724
hooks: hooks,
723725
op_name: op_name,
724726
stacktrace: stacktrace
725727
},
726728
nodes,
727729
cache_and_counts,
728-
%{mode: mode, debug?: debug?} = config
730+
%{mode: mode, debug?: debug?, global_layer_options: global_layer_options} = config
729731
)
730732
when (is_function(op) or is_atom(op)) and is_list(inputs) do
731733
# Traverse to accumulate cache and get parent_ids for
@@ -761,10 +763,12 @@ defmodule Axon.Compiler do
761763
name,
762764
args,
763765
opts,
766+
global_options,
764767
policy,
765768
layer_params,
766769
hooks,
767770
mode,
771+
global_layer_options,
768772
stacktrace
769773
)
770774

@@ -841,10 +845,12 @@ defmodule Axon.Compiler do
841845
name,
842846
args,
843847
opts,
848+
global_options,
844849
policy,
845850
layer_params,
846851
hooks,
847852
mode,
853+
global_layer_options,
848854
layer_stacktrace
849855
) do
850856
# Recurse graph inputs and invoke cache to get parent results,
@@ -914,7 +920,12 @@ defmodule Axon.Compiler do
914920

915921
# Compute arguments to be forwarded and ensure `:mode` is included
916922
# for inference/training behavior dependent functions
917-
args = Enum.reverse(tensor_inputs, [Keyword.put(opts, :mode, mode)])
923+
layer_opts =
924+
opts
925+
|> Keyword.merge(Keyword.take(global_layer_options, global_options))
926+
|> Keyword.put(:mode, mode)
927+
928+
args = Enum.reverse(tensor_inputs, [layer_opts])
918929

919930
# For built-in layers we always just apply the equivalent function
920931
# in Axon.Layers. The implication of this is that every function which

‎lib/axon/node.ex

+1
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ defmodule Axon.Node do
1212
:policy,
1313
:hooks,
1414
:opts,
15+
:global_options,
1516
:op_name,
1617
:stacktrace
1718
]

‎test/axon/compiler_test.exs

+23
Original file line number | Diff line number | Diff line change
@@ -5837,4 +5837,27 @@ defmodule CompilerTest do
58375837
assert predict_fn.(params, x) == Nx.add(x, a)
58385838
end
58395839
end
5840+
5841+
describe "global layer options" do
5842+
test "global options are forwarded to the layer when declared" do
5843+
input = Axon.input("input")
5844+
5845+
model =
5846+
Axon.layer(
5847+
fn input, opts ->
5848+
assert Keyword.has_key?(opts, :option1)
5849+
refute Keyword.has_key?(opts, :option2)
5850+
input
5851+
end,
5852+
[input],
5853+
global_options: [:option1]
5854+
)
5855+
5856+
{_, predict_fn} = Axon.build(model, global_layer_options: [option1: true, option2: true])
5857+
5858+
params = %{}
5859+
input = random({1, 1}, type: {:f, 32})
5860+
predict_fn.(params, input)
5861+
end
5862+
end
58405863
end

0 commit comments

Comments
 (0)