1- # Add keras.utils for the random seed
1+
22import os
33
44import keras
1616
1717class RandomElasticDeformation3DTest (TestCase ):
1818 def test_layer_basics (self ):
19- # --- BEST PRACTICE: Add a seed for reproducibility ---
19+
2020 utils .set_random_seed (0 )
2121 layer = RandomElasticDeformation3D (
2222 grid_size = (4 , 4 , 4 ),
@@ -32,8 +32,6 @@ def test_layer_basics(self):
3232 self .assertEqual (label .dtype , output_label .dtype )
3333
3434 def test_serialization (self ):
35- # --- BEST PRACTICE: Add a seed for reproducibility ---
36- utils .set_random_seed (0 )
3735 layer = RandomElasticDeformation3D (
3836 grid_size = (3 , 3 , 3 ),
3937 alpha = 50.0 ,
@@ -42,34 +40,39 @@ def test_serialization(self):
4240 image_data = ops .ones ((2 , 16 , 16 , 16 , 3 ), dtype = "float32" )
4341 label_data = ops .ones ((2 , 16 , 16 , 16 , 1 ), dtype = "int32" )
4442 input_data = (image_data , label_data )
43+
4544 image_input = Input (shape = (16 , 16 , 16 , 3 ), dtype = "float32" )
4645 label_input = Input (shape = (16 , 16 , 16 , 1 ), dtype = "int32" )
4746 outputs = layer ((image_input , label_input ))
4847 model = Model (inputs = [image_input , label_input ], outputs = outputs )
48+
49+
50+ utils .set_random_seed (0 )
4951 original_output_image , original_output_label = model (input_data )
50- path = os .path .join (self .get_temp_dir (), "model.keras" )
5152
52- # --- FIX: Remove the deprecated save_format argument ---
53+ path = os . path . join ( self . get_temp_dir (), "model.keras" )
5354 model .save (path )
54-
5555 loaded_model = keras .models .load_model (
56- path ,
57- custom_objects = {
58- "RandomElasticDeformation3D" : RandomElasticDeformation3D
59- },
56+ path , custom_objects = {"RandomElasticDeformation3D" : RandomElasticDeformation3D }
6057 )
58+
59+
60+ utils .set_random_seed (0 )
6161 loaded_output_image , loaded_output_label = loaded_model (input_data )
62+
63+
6264 np .testing .assert_allclose (
6365 ops .convert_to_numpy (original_output_image ),
6466 ops .convert_to_numpy (loaded_output_image ),
67+ atol = 1e-6
6568 )
6669 np .testing .assert_array_equal (
6770 ops .convert_to_numpy (original_output_label ),
6871 ops .convert_to_numpy (loaded_output_label ),
6972 )
7073
7174 def test_label_values_are_preserved (self ):
72- # --- BEST PRACTICE: Add a seed for reproducibility ---
75+
7376 utils .set_random_seed (0 )
7477 image = ops .zeros (shape = (1 , 16 , 16 , 16 , 1 ), dtype = "float32" )
7578 label_arange = ops .arange (16 ** 3 , dtype = "int32" )
0 commit comments