2121
2222TEST_2D_CASE = []
2323for sample_method in ["uniform" , "logit-normal" ]:
24- TEST_2D_CASE .append ([{"sample_method" : sample_method , "use_timestep_transform" : False }, (2 , 6 , 16 , 16 ), (2 , 6 , 16 , 16 )])
24+ TEST_2D_CASE .append (
25+ [{"sample_method" : sample_method , "use_timestep_transform" : False }, (2 , 6 , 16 , 16 ), (2 , 6 , 16 , 16 )]
26+ )
2527
2628for sample_method in ["uniform" , "logit-normal" ]:
27- TEST_2D_CASE .append ([{"sample_method" : sample_method , "use_timestep_transform" : True , "spatial_dim" : 2 }, (2 , 6 , 16 , 16 ), (2 , 6 , 16 , 16 )])
29+ TEST_2D_CASE .append (
30+ [
31+ {"sample_method" : sample_method , "use_timestep_transform" : True , "spatial_dim" : 2 },
32+ (2 , 6 , 16 , 16 ),
33+ (2 , 6 , 16 , 16 ),
34+ ]
35+ )
2836
2937
3038TEST_3D_CASE = []
3139for sample_method in ["uniform" , "logit-normal" ]:
32- TEST_3D_CASE .append ([{"sample_method" : sample_method , "use_timestep_transform" : False }, (2 , 6 , 16 , 16 , 16 ), (2 , 6 , 16 , 16 , 16 )])
40+ TEST_3D_CASE .append (
41+ [{"sample_method" : sample_method , "use_timestep_transform" : False }, (2 , 6 , 16 , 16 , 16 ), (2 , 6 , 16 , 16 , 16 )]
42+ )
3343
3444for sample_method in ["uniform" , "logit-normal" ]:
35- TEST_3D_CASE .append ([{"sample_method" : sample_method , "use_timestep_transform" : True , "spatial_dim" : 3 }, (2 , 6 , 16 , 16 , 16 ), (2 , 6 , 16 , 16 , 16 )])
45+ TEST_3D_CASE .append (
46+ [
47+ {"sample_method" : sample_method , "use_timestep_transform" : True , "spatial_dim" : 3 },
48+ (2 , 6 , 16 , 16 , 16 ),
49+ (2 , 6 , 16 , 16 , 16 ),
50+ ]
51+ )
3652
3753TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
3854
@@ -54,35 +70,35 @@ def test_add_noise(self, input_param, input_shape, expected_shape):
5470
5571 @parameterized .expand (TEST_CASES )
5672 def test_step_shape (self , input_param , input_shape , expected_shape ):
57- scheduler = RFlowScheduler (** input_param )
73+ scheduler = RFlowScheduler (** input_param )
5874 model_output = torch .randn (input_shape )
5975 sample = torch .randn (input_shape )
60- scheduler .set_timesteps (num_inference_steps = 100 , input_img_size_numel = torch .numel (sample [0 ,0 , ...]))
76+ scheduler .set_timesteps (num_inference_steps = 100 , input_img_size_numel = torch .numel (sample [0 , 0 , ...]))
6177 output_step = scheduler .step (model_output = model_output , timestep = 500 , sample = sample )
6278 self .assertEqual (output_step [0 ].shape , expected_shape )
6379 self .assertEqual (output_step [1 ].shape , expected_shape )
6480
6581 @parameterized .expand (TEST_FULl_LOOP )
6682 def test_full_timestep_loop (self , input_param , input_shape , expected_output ):
67- scheduler = RFlowScheduler (** input_param )
83+ scheduler = RFlowScheduler (** input_param )
6884 torch .manual_seed (42 )
6985 model_output = torch .randn (input_shape )
7086 sample = torch .randn (input_shape )
71- scheduler .set_timesteps (50 , input_img_size_numel = torch .numel (sample [0 ,0 , ...]))
87+ scheduler .set_timesteps (50 , input_img_size_numel = torch .numel (sample [0 , 0 , ...]))
7288 for t in range (50 ):
7389 sample , _ = scheduler .step (model_output = model_output , timestep = t , sample = sample )
7490 assert_allclose (sample , expected_output , rtol = 1e-3 , atol = 1e-3 )
7591
7692 def test_set_timesteps (self ):
7793 scheduler = RFlowScheduler (num_train_timesteps = 1000 )
78- scheduler .set_timesteps (num_inference_steps = 100 , input_img_size_numel = 16 * 16 * 16 )
94+ scheduler .set_timesteps (num_inference_steps = 100 , input_img_size_numel = 16 * 16 * 16 )
7995 self .assertEqual (scheduler .num_inference_steps , 100 )
8096 self .assertEqual (len (scheduler .timesteps ), 100 )
8197
8298 def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps (self ):
8399 scheduler = RFlowScheduler (num_train_timesteps = 1000 )
84100 with self .assertRaises (ValueError ):
85- scheduler .set_timesteps (num_inference_steps = 2000 , input_img_size_numel = 16 * 16 * 16 )
101+ scheduler .set_timesteps (num_inference_steps = 2000 , input_img_size_numel = 16 * 16 * 16 )
86102
87103
88104if __name__ == "__main__" :
0 commit comments