@@ -3731,16 +3731,15 @@ defmodule Axon do
3731
3731
relu layers with tanh layers:
3732
3732
3733
3733
new_model = Axon.map_nodes(model, fn
3734
- %Axon{op: :relu} = graph ->
3735
- # Get nodes immediate parent
3736
- parent = Axon.get_parent(graph)
3737
- # Replace node with a tanh
3738
- Axon.tanh(parent)
3734
+ %Axon.Node{op: :relu} = axon_node ->
3735
+ %{axon_node | op: :tanh}
3739
3736
3740
3737
graph ->
3741
3738
graph
3742
3739
end)
3743
3740
3741
+ For more complex graph rewriting and manipulation cases, see
3742
+ `Axon.rewrite_nodes/2`.
3744
3743
"""
3745
3744
@ doc type: :graph
3746
3745
def map_nodes ( % Axon { output: id , nodes: nodes } = axon , fun ) when is_function ( fun , 1 ) do
@@ -3779,6 +3778,74 @@ defmodule Axon do
3779
3778
Enum . reduce ( inorder_nodes , acc , fun )
3780
3779
end
3781
3780
3781
+ @ doc """
3782
+ Rewrite and manipulate nodes in the Axon execution graph.
3783
+
3784
+ Axon models are represented as a graph of nodes. Working on these nodes
3785
+ directly can be difficult and lead to disconnected and invalid graphs.
3786
+ In some cases, you simply want to rewrite patterns. This function takes
3787
+ an Axon model and traverses the nodes, applying the rewrite `fun` on each
3788
+ node to rewrite some or all of the nodes in the Axon model.
3789
+
3790
+ The rewrite function is an arity-1 function which takes the current Axon node
3791
+ as input and returns a function that replaces or rewrites the given node.
3792
+ For example, you can define a simple rewriter which replaces the `:relu`
3793
+ layers with `:tanh` layers:
3794
+
3795
+ tanh_rewriter = fn [%Axon{} = x], _output ->
3796
+ Axon.relu(x)
3797
+ end
3798
+
3799
+ Axon.rewrite_nodes(model, fn
3800
+ %Axon.Node{op: :relu} -> tanh_rewriter
3801
+ _ -> :skip
3802
+ end)
3803
+
3804
+ Notice that the rewriter receives all of the original graph inputs *as well as*
3805
+ the original graph outputs. This makes certain transformations which may rely
3806
+ on both the input and output, such as LoRA, much easier to perform.
3807
+ """
3808
+ @ doc type: :graph
3809
+ def rewrite_nodes ( % Axon { output: id , nodes: nodes } , fun ) when is_function ( fun , 1 ) do
3810
+ { inorder_nodes , _ } = traverse_nodes ( id , nodes , [ ] , MapSet . new ( ) )
3811
+
3812
+ updated_nodes =
3813
+ Enum . reduce ( inorder_nodes , nodes , fn
3814
+ % { id: original_id , parent: parents } = current_node , nodes ->
3815
+ rewriter = fun . ( current_node )
3816
+
3817
+ case rewriter do
3818
+ :skip ->
3819
+ nodes
3820
+
3821
+ rewriter when is_function ( rewriter , 2 ) ->
3822
+ input_axons = Enum . map ( parents , & % Axon { output: & 1 , nodes: nodes } )
3823
+ % Axon { output: swapped_id } = placeholder_output = Axon . input ( "placeholder_output" )
3824
+
3825
+ % Axon { output: new_node_id , nodes: updated_nodes } =
3826
+ rewriter . ( input_axons , placeholder_output )
3827
+
3828
+ # now we have to swap the IDs for the rewritten model so that
3829
+ # anything that references this node takes the new, rewritten form
3830
+ # as an input properly
3831
+ original_node = % { updated_nodes [ original_id ] | id: swapped_id }
3832
+ updated_node = % { updated_nodes [ new_node_id ] | id: original_id }
3833
+
3834
+ updated_nodes
3835
+ |> Map . replace ( swapped_id , original_node )
3836
+ |> Map . replace ( original_id , updated_node )
3837
+ end
3838
+ end )
3839
+
3840
+ # if we removed any nodes (like by just using the input instead)
3841
+ # then technically we will have extra nodes in the graph, so we
3842
+ # can prune them by traversing once again
3843
+ { pruned_nodes , _ } = traverse_nodes ( id , updated_nodes , [ ] , MapSet . new ( ) )
3844
+ pruned_nodes = Map . new ( pruned_nodes , fn % { id: id } = axon_node -> { id , axon_node } end )
3845
+
3846
+ % Axon { output: id , nodes: pruned_nodes }
3847
+ end
3848
+
3782
3849
defp traverse_nodes ( id , nodes , acc , visited ) do
3783
3850
if MapSet . member? ( visited , id ) do
3784
3851
{ acc , visited }
0 commit comments