@@ -36,6 +36,16 @@ defmodule Axon.ModelState do
36
36
end )
37
37
end
38
38
39
+ @ doc """
40
+ Merges 2 states with function.
41
+ """
42
+ # TODO: Don't assume these have the same shapes
43
+ def merge ( % ModelState { } = lhs , % ModelState { data: rhs_data } , fun ) when is_function ( fun , 3 ) do
44
+ update_in ( lhs , [ Access . key! ( :data ) ] , fn data ->
45
+ tree_merge ( data , rhs_data , fun )
46
+ end )
47
+ end
48
+
39
49
# TODO: Mask syntax with strings?
40
50
41
51
@ doc """
@@ -259,4 +269,59 @@ defmodule Axon.ModelState do
259
269
end
260
270
end )
261
271
end
272
+
273
+ defimpl Inspect do
274
+ import Inspect.Algebra
275
+
276
+ def inspect ( % Axon.ModelState { data: params } = model_state , opts ) do
277
+ { total_parameter_count , total_parameter_size } = get_param_info ( params )
278
+
279
+ { trainable_parameter_count , trainable_parameter_size } =
280
+ get_param_info ( Axon.ModelState . trainable_parameters ( model_state ) )
281
+
282
+ { trainable_state_count , trainable_state_size } =
283
+ get_param_info ( Axon.ModelState . trainable_state ( model_state ) )
284
+
285
+ inner =
286
+ concat ( [
287
+ line ( ) ,
288
+ "Parameters: #{ total_parameter_count } (#{ helpful_size ( total_parameter_size ) } )" ,
289
+ line ( ) ,
290
+ "Trainable Parameters: #{ trainable_parameter_count } (#{ helpful_size ( trainable_parameter_size ) } )" ,
291
+ line ( ) ,
292
+ "Trainable State: #{ trainable_state_count } , (#{ helpful_size ( trainable_state_size ) } )"
293
+ ] )
294
+
295
+ force_unfit (
296
+ concat ( [
297
+ color ( "#Axon.ModelState<" , :map , opts ) ,
298
+ nest ( inner , 2 ) ,
299
+ line ( ) ,
300
+ color ( ">" , :map , opts )
301
+ ] )
302
+ )
303
+ end
304
+
305
+ defp get_param_info ( params ) do
306
+ Enum . reduce ( params , { 0 , 0 } , fn
307
+ { _ , % Nx.Tensor { } = tensor } , { count , size } ->
308
+ { count + Nx . size ( tensor ) , size + Nx . byte_size ( tensor ) }
309
+
310
+ { _ , map } , { count , size } ->
311
+ { inner_count , inner_size } = get_param_info ( map )
312
+ { count + inner_count , size + inner_size }
313
+ end )
314
+ end
315
+
316
+ defp helpful_size ( n ) when n < 1_000 , do: "#{ n } B"
317
+
318
+ defp helpful_size ( n ) when n >= 1_000 and n < 1_000_000 ,
319
+ do: "#{ :io_lib . format ( ~c" ~.2f KB" , [ n / 1_000 ] ) } "
320
+
321
+ defp helpful_size ( n ) when n >= 1_000_000 and n < 1_000_000_000 ,
322
+ do: "#{ :io_lib . format ( ~c" ~.2f MB" , [ n / 1_000_000 ] ) } "
323
+
324
+ defp helpful_size ( n ) when n >= 1_000_000_000 and n < 1_000_000_000_000 ,
325
+ do: "#{ :io_lib . format ( ~c" ~.2f GB" , [ n / 1_000_000_000 ] ) } "
326
+ end
262
327
end
0 commit comments