Skip to content

Commit 0473cf9

Browse files
committed
Merge metadata in compiler, resolves #567
1 parent 5f9e7bc commit 0473cf9

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

lib/axon/compiler.ex

+3
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,9 @@ defmodule Axon.Compiler do
11281128
out = Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})
11291129
%{stateful | output: out}
11301130

1131+
%Nx.Tensor{data: %{op: :metadata, args: [arg, metadata]} = expr} = out ->
1132+
%{out | data: %{expr | args: [arg, Map.put(metadata, :axon_layer, op_name)]}}
1133+
11311134
%Nx.Tensor{} = out ->
11321135
Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})
11331136

test/axon/integration_test.exs

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ defmodule Axon.IntegrationTest do
3232
model_state =
3333
model
3434
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
35-
|> Axon.Loop.run(data, Axon.ModelState.empty(), iterations: 100, epochs: 10)
35+
|> Axon.Loop.run(data, Axon.ModelState.empty(), iterations: 100, epochs: 20)
3636

3737
eval_results =
3838
model

0 commit comments

Comments
 (0)