Skip to content

Commit 9fce600

Browse files
authored
Add rewrite_nodes function (#589)
1 parent c4d33e5 commit 9fce600

File tree

2 files changed

+190
-5
lines changed

2 files changed

+190
-5
lines changed

lib/axon.ex

+72-5
Original file line numberDiff line numberDiff line change
@@ -3731,16 +3731,15 @@ defmodule Axon do
37313731
relu layers with tanh layers:
37323732
37333733
new_model = Axon.map_nodes(model, fn
3734-
%Axon{op: :relu} = graph ->
3735-
# Get nodes immediate parent
3736-
parent = Axon.get_parent(graph)
3737-
# Replace node with a tanh
3738-
Axon.tanh(parent)
3734+
%Axon.Node{op: :relu} = axon_node ->
3735+
%{axon_node | op: :tanh}
37393736
37403737
graph ->
37413738
graph
37423739
end)
37433740
3741+
For more complex graph rewriting and manipulation cases, see
3742+
`Axon.rewrite_nodes/2`.
37443743
"""
37453744
@doc type: :graph
37463745
def map_nodes(%Axon{output: id, nodes: nodes} = axon, fun) when is_function(fun, 1) do
@@ -3779,6 +3778,74 @@ defmodule Axon do
37793778
Enum.reduce(inorder_nodes, acc, fun)
37803779
end
37813780

3781+
@doc """
Rewrite and manipulate nodes in the Axon execution graph.

Axon models are represented as a graph of nodes. Working on these nodes
directly can be difficult and lead to disconnected and invalid graphs.
In some cases, you simply want to rewrite patterns. This function takes
an Axon model and traverses the nodes, applying the rewrite `fun` on each
node to rewrite some or all of the nodes in the Axon model.

The rewrite function is an arity-1 function which takes the current Axon node
as input and returns a function that replaces or rewrites the given node.
For example, you can define a simple rewriter which replaces the `:relu`
layers with `:tanh` layers:

    tanh_rewriter = fn [%Axon{} = x], _output ->
      Axon.tanh(x)
    end

    Axon.rewrite_nodes(model, fn
      %Axon.Node{op: :relu} -> tanh_rewriter
      _ -> :skip
    end)

Notice that the rewriter receives all of the original graph inputs *as well as*
the original graph outputs. This makes certain transformations which may rely
on both the input and output, such as LoRA, much easier to perform.
"""
@doc type: :graph
def rewrite_nodes(%Axon{output: id, nodes: nodes}, fun) when is_function(fun, 1) do
  # Visit nodes in topological (in-order) fashion so each rewrite sees
  # parents that have already been processed.
  {inorder_nodes, _} = traverse_nodes(id, nodes, [], MapSet.new())

  updated_nodes =
    Enum.reduce(inorder_nodes, nodes, fn
      %{id: original_id, parent: parents} = current_node, nodes ->
        rewriter = fun.(current_node)

        case rewriter do
          :skip ->
            # No rewrite requested for this node; keep the graph as-is.
            nodes

          rewriter when is_function(rewriter, 2) ->
            # Wrap each parent id back up as an %Axon{} so the rewriter can
            # use the normal layer-building API on the node's inputs.
            input_axons = Enum.map(parents, &%Axon{output: &1, nodes: nodes})

            # Placeholder standing in for this node's original output, so
            # rewriters (e.g. residual/LoRA) can reference it directly.
            %Axon{output: swapped_id} = placeholder_output = Axon.input("placeholder_output")

            %Axon{output: new_node_id, nodes: updated_nodes} =
              rewriter.(input_axons, placeholder_output)

            # Now we have to swap the IDs for the rewritten model so that
            # anything that references this node takes the new, rewritten form
            # as an input properly.
            original_node = %{updated_nodes[original_id] | id: swapped_id}
            updated_node = %{updated_nodes[new_node_id] | id: original_id}

            # Map.replace/3 is a no-op for absent keys, so an unused
            # placeholder simply drops out of the graph.
            updated_nodes
            |> Map.replace(swapped_id, original_node)
            |> Map.replace(original_id, updated_node)
        end
    end)

  # If we removed any nodes (like by just using the input instead)
  # then technically we will have extra nodes in the graph, so we
  # can prune them by traversing once again.
  {pruned_nodes, _} = traverse_nodes(id, updated_nodes, [], MapSet.new())
  pruned_nodes = Map.new(pruned_nodes, fn %{id: id} = axon_node -> {id, axon_node} end)

  %Axon{output: id, nodes: pruned_nodes}
end
3848+
37823849
defp traverse_nodes(id, nodes, acc, visited) do
37833850
if MapSet.member?(visited, id) do
37843851
{acc, visited}

test/axon/compiler_test.exs

+118
Original file line numberDiff line numberDiff line change
@@ -5675,4 +5675,122 @@ defmodule CompilerTest do
56755675
assert out =~ "bar:"
56765676
end
56775677
end
5678+
5679+
describe "graph manipulation" do
5680+
test "rewrite_nodes does nothing if all rewrites are skip" do
  # A rewriter that skips every node must leave the graph untouched.
  base_model = Axon.dense(Axon.input("x"), 10, activation: :relu)
  rewritten = Axon.rewrite_nodes(base_model, fn _ -> :skip end)

  {init_fn, predict_fn} = Axon.build(rewritten)
  input = Nx.broadcast(1, {1, 10})

  # Grab the initialized dense parameters so we can compute the
  # expected output by hand.
  model_state = init_fn.(input, ModelState.empty())
  %ModelState{data: %{"dense_0" => %{"kernel" => kernel, "bias" => bias}}} = model_state

  expected = Axon.Activations.relu(Axon.Layers.dense(input, kernel, bias))
  assert_equal(predict_fn.(model_state, input), expected)
end
5698+
5699+
test "rewrite_nodes applies simple rewriters" do
  # Rewriter that swaps the matched node for a tanh over its single parent.
  relu_rewriter = fn [%Axon{} = x], _ ->
    Axon.tanh(x)
  end

  model =
    Axon.input("x")
    |> Axon.dense(10, activation: :relu)

  # Apply the rewriter only to :relu nodes; everything else is untouched.
  model =
    Axon.rewrite_nodes(model, fn
      %Axon.Node{op: :relu} -> relu_rewriter
      _ -> :skip
    end)

  {init_fn, predict_fn} = Axon.build(model)
  input = Nx.broadcast(1, {1, 10})

  %ModelState{data: %{"dense_0" => %{"kernel" => k, "bias" => b}}} =
    model_state = init_fn.(input, ModelState.empty())

  # The rewritten model should compute tanh(dense(x)) instead of relu(dense(x)).
  assert_equal(
    predict_fn.(model_state, input),
    Axon.Activations.tanh(Axon.Layers.dense(input, k, b))
  )
end
5725+
5726+
test "rewrite_nodes applies residual rewriter" do
  # Rewriter that uses BOTH the node's input and its original output,
  # forming a residual (skip) connection around the matched node.
  residual_rewriter = fn [%Axon{} = x], %Axon{} = out ->
    Axon.add(x, out)
  end

  model =
    Axon.input("x")
    |> Axon.dense(10, activation: :relu)

  model =
    Axon.rewrite_nodes(model, fn
      %Axon.Node{op: :dense} -> residual_rewriter
      _ -> :skip
    end)

  {init_fn, predict_fn} = Axon.build(model)
  input = Nx.broadcast(1, {1, 10})

  %ModelState{data: %{"dense_0" => %{"kernel" => k, "bias" => b}}} =
    model_state = init_fn.(input, ModelState.empty())

  # Expected: the residual add happens BEFORE the relu activation,
  # i.e. relu(dense(x) + x), since only the dense node was rewritten.
  real_fn = fn input, k, b ->
    out = Nx.add(Axon.Layers.dense(input, k, b), input)
    Axon.Activations.relu(out)
  end

  assert_equal(predict_fn.(model_state, input), real_fn.(input, k, b))
end
5754+
5755+
test "rewrite_nodes properly removes layers" do
  # Rewriter that replaces the matched node with its own input,
  # effectively deleting the node from the graph.
  remove_relu_rewriter = fn [%Axon{} = x], _out ->
    x
  end

  input = Axon.input("x")
  # Side branch that also contains a relu, to confirm removal applies
  # across all graph paths (including merge points like Axon.add).
  relu_tanh_input = Axon.tanh(Axon.relu(input))

  model =
    input
    |> Axon.relu()
    |> Axon.tanh()
    |> Axon.relu()
    |> Axon.tanh()
    |> Axon.tanh()
    |> Axon.relu()
    |> Axon.relu()
    |> Axon.add(relu_tanh_input)

  model =
    Axon.rewrite_nodes(model, fn
      %Axon.Node{op: :relu} -> remove_relu_rewriter
      _ -> :skip
    end)

  {_, predict_fn} = Axon.build(model)
  input = Nx.broadcast(1, {1, 10})

  # With every relu removed, only the three tanh layers on the main path
  # and the single tanh on the side branch remain.
  real_fn = fn input ->
    tanh_input = Axon.Activations.tanh(input)

    input
    |> Axon.Activations.tanh()
    |> Axon.Activations.tanh()
    |> Axon.Activations.tanh()
    |> Nx.add(tanh_input)
  end

  assert_equal(predict_fn.(ModelState.empty(), input), real_fn.(input))
end
5795+
end
56785796
end

0 commit comments

Comments
 (0)