Commit efd4c1f

Add inspect protocol for model state
1 parent f48dcd1 commit efd4c1f

2 files changed (+86 −0 lines)


lib/axon.ex (+21 lines)
@@ -548,6 +548,27 @@ defmodule Axon do
     layer(:optional, [x], name: opts[:name], meta: opts[:meta], op_name: :optional)
   end
 
+  @doc """
+  Implements an or-else (like Elixir's `||`): returns the first input unless it is `%Axon.None{}`.
+  """
+  @doc type: :special
+  def or_else(%Axon{} = a, %Axon{} = b, opts \\ []) do
+    opts = Keyword.validate!(opts, [:name, :meta])
+
+    Axon.layer(
+      fn x, y, _ ->
+        case x do
+          %Axon.None{} -> y
+          _ -> x
+        end
+      end,
+      [a, b],
+      op_name: :or_else,
+      name: opts[:name],
+      meta: opts[:meta]
+    )
+  end
+
   @doc """
   Adds a constant layer to the network.
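
A usage sketch for the new or_else/3 (not part of this commit; the input names, shapes, and the optional: true input option are illustrative assumptions): prefer an optional input when the caller supplies it, otherwise fall back to a computed branch.

    # Hypothetical sketch: use "user_bias" when provided, else a learned projection.
    features = Axon.input("features", shape: {nil, 8})
    user_bias = Axon.input("user_bias", shape: {nil, 4}, optional: true)

    fallback = Axon.dense(features, 4)

    # or_else/3 returns the first branch unless it resolves to %Axon.None{},
    # in which case the second branch is used.
    output = Axon.or_else(user_bias, fallback, name: "bias_or_fallback")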

lib/axon/model_state.ex (+65 lines)
@@ -36,6 +36,16 @@ defmodule Axon.ModelState do
     end)
   end
 
+  @doc """
+  Merges two model states, combining matching entries with `fun`.
+  """
+  # TODO: Don't assume these have the same shapes
+  def merge(%ModelState{} = lhs, %ModelState{data: rhs_data}, fun) when is_function(fun, 3) do
+    update_in(lhs, [Access.key!(:data)], fn data ->
+      tree_merge(data, rhs_data, fun)
+    end)
+  end
+
   # TODO: Mask syntax with strings?
 
   @doc """
@@ -259,4 +269,59 @@
       end
     end)
   end
+
+  defimpl Inspect do
+    import Inspect.Algebra
+
+    def inspect(%Axon.ModelState{data: params} = model_state, opts) do
+      {total_parameter_count, total_parameter_size} = get_param_info(params)
+
+      {trainable_parameter_count, trainable_parameter_size} =
+        get_param_info(Axon.ModelState.trainable_parameters(model_state))
+
+      {trainable_state_count, trainable_state_size} =
+        get_param_info(Axon.ModelState.trainable_state(model_state))
+
+      inner =
+        concat([
+          line(),
+          "Parameters: #{total_parameter_count} (#{helpful_size(total_parameter_size)})",
+          line(),
+          "Trainable Parameters: #{trainable_parameter_count} (#{helpful_size(trainable_parameter_size)})",
+          line(),
+          "Trainable State: #{trainable_state_count} (#{helpful_size(trainable_state_size)})"
+        ])
+
+      force_unfit(
+        concat([
+          color("#Axon.ModelState<", :map, opts),
+          nest(inner, 2),
+          line(),
+          color(">", :map, opts)
+        ])
+      )
+    end
+
+    defp get_param_info(params) do
+      Enum.reduce(params, {0, 0}, fn
+        {_, %Nx.Tensor{} = tensor}, {count, size} ->
+          {count + Nx.size(tensor), size + Nx.byte_size(tensor)}
+
+        {_, map}, {count, size} ->
+          {inner_count, inner_size} = get_param_info(map)
+          {count + inner_count, size + inner_size}
+      end)
+    end
+
+    defp helpful_size(n) when n < 1_000, do: "#{n} B"
+
+    defp helpful_size(n) when n >= 1_000 and n < 1_000_000,
+      do: "#{:io_lib.format(~c"~.2f KB", [n / 1_000])}"
+
+    defp helpful_size(n) when n >= 1_000_000 and n < 1_000_000_000,
+      do: "#{:io_lib.format(~c"~.2f MB", [n / 1_000_000])}"
+
+    defp helpful_size(n) when n >= 1_000_000_000 and n < 1_000_000_000_000,
+      do: "#{:io_lib.format(~c"~.2f GB", [n / 1_000_000_000])}"
+  end
 end
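
With the Inspect implementation above, a model state renders as a compact summary instead of a raw nested parameter map. An illustrative sketch follows (the model, the use of Axon.ModelState.empty/0 as the initial state, and the printed numbers are assumptions, not part of this commit); a dense 8 => 4 layer has 36 f32 parameters, i.e. 144 bytes.

    model = Axon.input("x", shape: {nil, 8}) |> Axon.dense(4)
    {init_fn, _predict_fn} = Axon.build(model)
    model_state = init_fn.(Nx.template({1, 8}, :f32), Axon.ModelState.empty())

    # Inspecting model_state now prints something like:
    #
    #   #Axon.ModelState<
    #     Parameters: 36 (144 B)
    #     Trainable Parameters: 36 (144 B)
    #     Trainable State: 0 (0 B)
    #   >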
